small pixel drawing of a pufferfish cascade

internal/agent/dns.go

package agent

import (
	"fmt"
	"net"
	"strings"
	"sync"

	"github.com/hashicorp/serf/serf"
	"github.com/miekg/dns"
)

// DNSServer answers queries for <node>.node.<domain>,
// <svc>.service.<domain> and <tag>.<svc>.service.<domain>. The domain is
// configurable via Config.DNSDomain and defaults to "consul" so consul SDKs
// that hardcode .consul resolve without changes.
//
// Construct it with newDNSServer; Start binds the listeners and Shutdown
// stops them.
type DNSServer struct {
	// agent provides Config, the logger, serf membership and the service
	// catalog (CatalogServiceInstances) consulted when answering queries.
	agent  *Agent
	domain string // normalized: lowercase, no leading/trailing dot

	// udp and tcp are the servers created by Start (nil before then); wg
	// tracks their serve goroutines so Shutdown can wait for them to exit.
	udp *dns.Server
	tcp *dns.Server
	wg  sync.WaitGroup
}

// newDNSServer builds a DNSServer for agent a. The served domain comes from
// a.Config.DNSDomain, lowercased and stripped of leading/trailing dots; if
// that leaves nothing, DefaultDNSDomain is used instead.
func newDNSServer(a *Agent) *DNSServer {
	srv := &DNSServer{agent: a, domain: DefaultDNSDomain}
	if configured := strings.ToLower(strings.Trim(a.Config.DNSDomain, ".")); configured != "" {
		srv.domain = configured
	}
	return srv
}

// Start binds the UDP and TCP DNS listeners on Config.DNSBindAddr and serves
// both from background goroutines tracked by d.wg.
//
// It returns only after both dns.Servers have signalled, via
// NotifyStartedFunc, that they are serving. Without that wait, a Shutdown
// issued right after Start can reach a dns.Server before it has marked
// itself started. dns.Server.Shutdown then fails with "server not started"
// (an error Shutdown discards), the server starts anyway, and Shutdown's
// wg.Wait never returns.
//
// If either server exits before it starts, both are torn down and that
// error is returned. Errors after a successful start are only logged.
func (d *DNSServer) Start() error {
	// JoinHostPort brackets IPv6 literals ("[::1]:8600"); "%s:%d" would
	// produce "::1:8600", which Listen rejects as having too many colons.
	addr := net.JoinHostPort(
		fmt.Sprint(d.agent.Config.DNSBindAddr.IP),
		fmt.Sprint(d.agent.Config.DNSBindAddr.Port),
	)
	handler := dns.HandlerFunc(d.handle)

	udpConn, err := net.ListenPacket("udp", addr)
	if err != nil {
		return fmt.Errorf("dns udp listen %s: %w", addr, err)
	}
	tcpListener, err := net.Listen("tcp", addr)
	if err != nil {
		udpConn.Close()
		return fmt.Errorf("dns tcp listen %s: %w", addr, err)
	}

	d.udp = &dns.Server{PacketConn: udpConn, Handler: handler}
	d.tcp = &dns.Server{Listener: tcpListener, Handler: handler}

	// Each server sends exactly one value on ready. It sends nil once it
	// has started, or ActivateAndServe's error if that returned first. The
	// buffer means no goroutine ever blocks on the send.
	ready := make(chan error, 2)
	launch := func(proto string, srv *dns.Server) {
		var once sync.Once
		srv.NotifyStartedFunc = func() {
			once.Do(func() { ready <- nil })
		}
		d.wg.Add(1)
		go func() {
			defer d.wg.Done()
			err := srv.ActivateAndServe()
			beforeStart := false
			once.Do(func() {
				beforeStart = true
				ready <- err
			})
			// Startup failures are returned by Start; only log later ones.
			if err != nil && !beforeStart {
				d.agent.logger.Error("dns "+proto+" serve", err)
			}
		}()
	}
	launch("udp", d.udp)
	launch("tcp", d.tcp)

	var startErr error
	for i := 0; i < 2; i++ {
		if err := <-ready; err != nil && startErr == nil {
			startErr = err
		}
	}
	if startErr == nil {
		return nil
	}

	// Stop whichever server did start, then close both sockets. A server
	// that failed before starting may not have closed its own; double
	// closes only return ignorable errors.
	d.Shutdown()
	udpConn.Close()
	tcpListener.Close()
	return fmt.Errorf("dns serve %s: %w", addr, startErr)
}

// Shutdown stops whichever DNS servers Start created (UDP first, then TCP)
// and blocks until their serve goroutines have exited. It is meant to be
// safe to call more than once: errors from dns.Server.Shutdown, such as
// for a server that is already stopped, are deliberately discarded.
func (d *DNSServer) Shutdown() {
	for _, srv := range []*dns.Server{d.udp, d.tcp} {
		if srv == nil {
			continue // Start never created this server
		}
		_ = srv.Shutdown()
	}
	d.wg.Wait()
}

// handle is the single dispatch entrypoint for all DNS queries.
//
// Only the first question is answered. Replies are authoritative and never
// offer recursion. Outcomes:
//   - no question: FORMERR;
//   - a name outside the configured domain: REFUSED;
//   - <...>.service and <...>.node names: delegated to answerService and
//     answerNode;
//   - the domain apex and the bare "service" and "node" names: NOERROR
//     with no records (NODATA);
//   - anything else: NXDOMAIN.
//
// Before writing, the reply is truncated to fit the client's receive
// buffer (see below).
func (d *DNSServer) handle(w dns.ResponseWriter, req *dns.Msg) {
	resp := new(dns.Msg)
	resp.SetReply(req)
	resp.Authoritative = true
	resp.RecursionAvailable = false

	if len(req.Question) == 0 {
		resp.SetRcode(req, dns.RcodeFormatError)
		_ = w.WriteMsg(resp) // best effort: nothing to do if the client is gone
		return
	}
	q := req.Question[0]

	labels, ok := d.stripDomain(q.Name)
	if !ok {
		resp.SetRcode(req, dns.RcodeRefused)
		_ = w.WriteMsg(resp)
		return
	}

	switch {
	case isService(labels):
		d.answerService(resp, q, labels)
	case isNode(labels):
		d.answerNode(resp, q, labels)
	case len(labels) == 0,
		len(labels) == 1 && (labels[0] == "service" || labels[0] == "node"):
		// These names are parents of the names answered above, so they
		// exist even though they hold no records. NXDOMAIN would tell
		// resolvers (RFC 8020) that nothing beneath them exists either,
		// which breaks QNAME-minimising resolvers. Reply NOERROR with an
		// empty answer section (NODATA) instead.
	default:
		resp.SetRcode(req, dns.RcodeNameError) // NXDOMAIN
	}

	// Fit the reply into what the client can receive. Over UDP that is 512
	// bytes unless the query advertised a larger EDNS0 buffer. Truncate drops
	// records that do not fit and sets TC when answers are cut, so the client
	// retries over TCP, where the limit is the 64 KiB message maximum.
	size := dns.MaxMsgSize
	if _, isUDP := w.RemoteAddr().(*net.UDPAddr); isUDP {
		size = dns.MinMsgSize
		if opt := req.IsEdns0(); opt != nil {
			size = int(opt.UDPSize())
		}
	}
	resp.Truncate(size)

	_ = w.WriteMsg(resp)
}

// stripDomain drops qname's trailing root dot, lowercases it, and returns
// the labels that precede the configured domain. For example,
// "web.service.consul." with domain "consul" yields ["web", "service"].
// The domain apex yields (nil, true), and names outside the domain yield
// (nil, false).
func (d *DNSServer) stripDomain(qname string) ([]string, bool) {
	fqdn := strings.ToLower(strings.TrimSuffix(qname, "."))
	if fqdn == d.domain {
		return nil, true
	}

	zone := "." + d.domain
	if !strings.HasSuffix(fqdn, zone) {
		return nil, false
	}

	prefix := fqdn[:len(fqdn)-len(zone)]
	if prefix == "" {
		return nil, true
	}
	return strings.Split(prefix, "."), true
}

// isService reports whether labels (the query name minus the domain) end in
// "service" with at least one label in front of it.
func isService(labels []string) bool {
	if n := len(labels); n > 1 {
		return labels[n-1] == "service"
	}
	return false
}

// isNode reports whether labels (the query name minus the domain) end in
// "node" with at least one label in front of it.
func isNode(labels []string) bool {
	n := len(labels)
	switch {
	case n < 2:
		return false
	default:
		return labels[n-1] == "node"
	}
}

// answerNode handles <name>.node.<domain> queries (A only).
//
// labels must be exactly [name, "node"]. Any other shape, or a name that
// matches no serf member, sets NXDOMAIN. For an A or ANY query on a
// member with an IPv4 address, one A record (TTL 0) is appended. Other
// query types get NOERROR with no answer.
func (d *DNSServer) answerNode(resp *dns.Msg, q dns.Question, labels []string) {
	var (
		member serf.Member
		found  bool
	)
	// Deeper names under .node. don't exist, so skip the lookup for them.
	if len(labels) == 2 {
		member, found = d.findMember(labels[0])
	}
	if !found {
		resp.Rcode = dns.RcodeNameError
		return
	}

	if q.Qtype != dns.TypeA && q.Qtype != dns.TypeANY {
		return
	}
	v4 := member.Addr.To4()
	if v4 == nil {
		return
	}
	resp.Answer = append(resp.Answer, &dns.A{
		Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0},
		A:   v4,
	})
}

// answerService handles <svc>.service.<domain> and
// <tag>.<svc>.service.<domain> queries (A + SRV).
//
// It sets NXDOMAIN for any other label shape, an unknown service, or a tag
// that no instance carries. Otherwise, for each instance that serviceIP can
// resolve:
//   - A/ANY queries get an A record for the instance address;
//   - SRV/ANY queries get an SRV record targeting <node>.node.<domain> on
//     the service port, plus an A glue record for that target in Extra.
//
// Several instances can produce identical records, e.g. one node running
// the service on several ports. Each distinct record is emitted only once,
// because RFC 2181 §5 forbids duplicate RRs within an RRset. All records
// use TTL 0.
func (d *DNSServer) answerService(resp *dns.Msg, q dns.Question, labels []string) {
	// shapes: [name, "service"]  or  [tag, name, "service"]
	var name, tag string
	switch len(labels) {
	case 2:
		name = labels[0]
	case 3:
		tag, name = labels[0], labels[1]
	default:
		resp.Rcode = dns.RcodeNameError
		return
	}

	instances := d.agent.CatalogServiceInstances(name)
	if tag != "" {
		instances = filterByTag(instances, tag)
	}
	if len(instances) == 0 {
		resp.Rcode = dns.RcodeNameError
		return
	}

	wantA := q.Qtype == dns.TypeA || q.Qtype == dns.TypeANY
	wantSRV := q.Qtype == dns.TypeSRV || q.Qtype == dns.TypeANY

	// firstTime records key and reports whether it had not been seen before.
	// Keys are prefixed by record kind so A, SRV and glue never collide.
	seen := make(map[string]struct{})
	firstTime := func(key string) bool {
		if _, dup := seen[key]; dup {
			return false
		}
		seen[key] = struct{}{}
		return true
	}

	for _, inst := range instances {
		ip := d.serviceIP(inst)
		if ip == nil {
			continue // no IPv4 address to advertise for this instance
		}
		if wantA && firstTime("a "+ip.String()) {
			resp.Answer = append(resp.Answer, &dns.A{
				Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0},
				A:   ip,
			})
		}
		if !wantSRV {
			continue
		}

		target := dns.Fqdn(fmt.Sprintf("%s.node.%s", normalizeLabel(inst.Node), d.domain))
		port := uint16(inst.Service.Port)
		if firstTime(fmt.Sprintf("srv %s %d", target, port)) {
			resp.Answer = append(resp.Answer, &dns.SRV{
				Hdr:      dns.RR_Header{Name: q.Name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: 0},
				Priority: 1,
				Weight:   1,
				Port:     port,
				Target:   target,
			})
		}
		// Add an A record in Extra so clients don't need a second lookup.
		// NOTE(review): when inst.Service.Address is set, this glue claims
		// target has the service address, while answerNode answers the same
		// name with the member's own address. Clients that re-resolve the
		// target therefore get the node address. A proper fix needs a
		// per-service target name.
		if firstTime("glue " + target + " " + ip.String()) {
			resp.Extra = append(resp.Extra, &dns.A{
				Hdr: dns.RR_Header{Name: target, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0},
				A:   ip,
			})
		}
	}
}

// serviceIP picks the IPv4 address an instance should be advertised at.
// A non-empty Service.Address that parses as an IP wins, and yields nil
// if that IP is not IPv4. An empty or unparseable address falls back to
// the address of the serf member named by inst.Node. The result is nil
// when neither source yields an IPv4 address.
func (d *DNSServer) serviceIP(inst NodeService) net.IP {
	var parsed net.IP
	if inst.Service.Address != "" {
		parsed = net.ParseIP(inst.Service.Address)
	}
	if parsed != nil {
		return parsed.To4()
	}

	member, found := d.findMember(inst.Node)
	if !found {
		return nil
	}
	return member.Addr.To4()
}

// findMember returns the first serf member whose name equals name after
// both have been passed through normalizeLabel, the same mapping used to
// render node names into DNS labels. Normalizing the input as well means
// callers may pass either a query label (answerNode) or a raw catalog node
// name such as "agent 1" (serviceIP) and match the same member. Before,
// a raw name containing a space never matched its own normalized form.
// The search is a linear scan over d.agent.serf.Members().
func (d *DNSServer) findMember(name string) (serf.Member, bool) {
	target := normalizeLabel(name)
	for _, m := range d.agent.serf.Members() {
		// normalizeLabel already lowercases, so no extra ToLower is needed.
		if normalizeLabel(m.Name) == target {
			return m, true
		}
	}
	return serf.Member{}, false
}

// filterByTag returns the instances that carry tag, compared
// case-insensitively. The result never shares a backing array with
// instances: the [:0:0] slice has zero capacity, so append copies.
func filterByTag(instances []NodeService, tag string) []NodeService {
	tag = strings.ToLower(tag)
	out := instances[:0:0]
	for _, inst := range instances {
		for _, t := range inst.Service.Tags {
			if strings.ToLower(t) == tag {
				out = append(out, inst)
				break
			}
		}
	}
	return out
}

// normalizeLabel makes a node or service name safe to embed as a single DNS
// label. It lowercases the name and replaces spaces and dots with hyphens.
// Spaces come from test nodes named "agent 1" etc. Dots must go because
// they separate labels: node names are often hostnames, which may be
// dotted FQDNs. Left intact, a name like "web.prod" would turn the SRV
// target into web.prod.node.<domain>, which answerNode rejects as
// NXDOMAIN because it only accepts [name, "node"].
//
// The mapping is not injective ("a.b", "a b" and "a-b" all become "a-b").
// findMember returns the first member that matches.
func normalizeLabel(s string) string {
	s = strings.ToLower(s)
	s = strings.ReplaceAll(s, " ", "-")
	s = strings.ReplaceAll(s, ".", "-")
	return s
}