diff --git a/client/internal/connect.go b/client/internal/connect.go index c2b7db0e2..f20b8d361 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -275,11 +275,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engineMutex.Unlock() - if err := c.engine.Start(); err != nil { + if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } + log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go new file mode 100644 index 000000000..cb651f1e5 --- /dev/null +++ b/client/internal/dns/config/domains.go @@ -0,0 +1,201 @@ +package config + +import ( + "errors" + "fmt" + "net" + "net/netip" + "net/url" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/shared/management/domain" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +var ( + ErrEmptyURL = errors.New("empty URL") + ErrEmptyHost = errors.New("empty host") + ErrIPNotAllowed = errors.New("IP address not allowed") +) + +// ServerDomains represents the management server domains extracted from NetBird configuration +type ServerDomains struct { + Signal domain.Domain + Relay []domain.Domain + Flow domain.Domain + Stuns []domain.Domain + Turns []domain.Domain +} + +// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration +func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { + if config == nil { + return ServerDomains{} + } + + domains := ServerDomains{} + + domains.Signal = extractSignalDomain(config) + domains.Relay = extractRelayDomains(config) + domains.Flow = extractFlowDomain(config) + domains.Stuns = extractStunDomains(config) + domains.Turns = extractTurnDomains(config) + + return domains +} + +// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses +func ExtractValidDomain(rawURL string) (domain.Domain, error) { + if rawURL == "" { + return "", ErrEmptyURL + } + + parsedURL, err := url.Parse(rawURL) + if err == nil { + if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" { + return domain, err + } + } + + return extractFromRawString(rawURL) +} + +// extractFromParsedURL handles domain extraction from successfully parsed URLs +func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) { + if parsedURL.Hostname() != "" { + return extractDomainFromHost(parsedURL.Hostname()) + } + + if parsedURL.Opaque == "" || parsedURL.Scheme == "" { + return "", nil + } + + // Handle URLs with opaque content (e.g., stun:host:port) + if strings.Contains(parsedURL.Scheme, ".") { + // This is likely "domain.com:port" being parsed as scheme:opaque + reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque + if host, _, err := net.SplitHostPort(reconstructed); err == nil { + return extractDomainFromHost(host) + } + return extractDomainFromHost(parsedURL.Scheme) + } + + // Valid scheme with opaque content (e.g., stun:host:port) + host := parsedURL.Opaque + if queryIndex := strings.Index(host, "?"); queryIndex > 0 { + host = host[:queryIndex] + } + + if hostOnly, _, err := net.SplitHostPort(host); err == nil { + return extractDomainFromHost(hostOnly) + } + + return extractDomainFromHost(host) +} + +// extractFromRawString handles domain extraction when URL parsing fails or returns no results +func extractFromRawString(rawURL string) (domain.Domain, error) { + if host, _, err := net.SplitHostPort(rawURL); err == nil { + return extractDomainFromHost(host) + } + + return extractDomainFromHost(rawURL) +} + +// extractDomainFromHost extracts domain from a host string, filtering out IP addresses +func extractDomainFromHost(host string) (domain.Domain, error) { + if host == "" { + return "", ErrEmptyHost + } + + if _, err := netip.ParseAddr(host); err == nil { + return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host) + } + + d, err := domain.FromString(host) + if err != nil { + return "", fmt.Errorf("invalid domain: %v", err) + } + + return d, nil +} + +// extractSingleDomain extracts a single domain from a URL with error logging +func extractSingleDomain(url, serviceType string) domain.Domain { + if url == "" { + return "" + } + + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + return "" + } + + return d +} + +// extractMultipleDomains extracts multiple domains from URLs with error logging +func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { + var domains []domain.Domain + for _, url := range urls { + if url == "" { + continue + } + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + continue + } + domains = append(domains, d) + } + return domains +} + +// extractSignalDomain extracts the signal domain from NetBird configuration. +func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Signal != nil { + return extractSingleDomain(config.Signal.Uri, "signal") + } + return "" +} + +// extractRelayDomains extracts relay server domains from NetBird configuration. +func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Relay != nil { + return extractMultipleDomains(config.Relay.Urls, "relay") + } + return nil +} + +// extractFlowDomain extracts the traffic flow domain from NetBird configuration. +func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Flow != nil { + return extractSingleDomain(config.Flow.Url, "flow") + } + return "" +} + +// extractStunDomains extracts STUN server domains from NetBird configuration. +func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, stun := range config.Stuns { + if stun != nil && stun.Uri != "" { + urls = append(urls, stun.Uri) + } + } + return extractMultipleDomains(urls, "STUN") +} + +// extractTurnDomains extracts TURN server domains from NetBird configuration. +func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, turn := range config.Turns { + if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { + urls = append(urls, turn.HostConfig.Uri) + } + } + return extractMultipleDomains(urls, "TURN") +} diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go new file mode 100644 index 000000000..5eae3a541 --- /dev/null +++ b/client/internal/dns/config/domains_test.go @@ -0,0 +1,213 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractValidDomain(t *testing.T) { + tests := []struct { + name string + url string + expected string + expectError bool + }{ + { + name: "HTTPS URL with port", + url: "https://api.netbird.io:443", + expected: "api.netbird.io", + }, + { + name: "HTTP URL without port", + url: "http://signal.example.com", + expected: "signal.example.com", + }, + { + name: "Host with port (no scheme)", + url: "signal.netbird.io:443", + expected: "signal.netbird.io", + }, + { + name: "STUN URL", + url: "stun:stun.netbird.io:443", + expected: "stun.netbird.io", + }, + { + name: "STUN URL with different port", + url: "stun:stun.netbird.io:5555", + expected: "stun.netbird.io", + }, + { + name: "TURNS URL with query params", + url: "turns:turn.netbird.io:443?transport=tcp", + expected: "turn.netbird.io", + }, + { + name: "TURN URL", + url: "turn:turn.example.com:3478", + expected: "turn.example.com", + }, + { + name: "REL URL", + url: "rel://relay.example.com:443", + expected: "relay.example.com", + }, + { + name: "RELS URL", + url: "rels://relay.netbird.io:443", + expected: "relay.netbird.io", + }, + { + name: "Raw hostname", + url: "example.org", + expected: "example.org", + }, + { + name: "IP address should be rejected", + url: "192.168.1.1", + expectError: true, + }, + { + name: "IP address with port should be rejected", + url: "192.168.1.1:443", + expectError: true, + }, + { + name: "IPv6 address should be rejected", + url: "2001:db8::1", + expectError: true, + }, + { + name: "HTTP URL with IPv4 should be rejected", + url: "http://192.168.1.1:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv4 should be rejected", + url: "https://10.0.0.1:443", + expectError: true, + }, + { + name: "STUN URL with IPv4 should be rejected", + url: "stun:192.168.1.1:3478", + expectError: true, + }, + { + name: "TURN URL with IPv4 should be rejected", + url: "turn:10.0.0.1:3478", + expectError: true, + }, + { + name: "TURNS URL with IPv4 should be rejected", + url: "turns:172.16.0.1:5349", + expectError: true, + }, + { + name: "HTTP URL with IPv6 should be rejected", + url: "http://[2001:db8::1]:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv6 should be rejected", + url: "https://[::1]:443", + expectError: true, + }, + { + name: "STUN URL with IPv6 should be rejected", + url: "stun:[2001:db8::1]:3478", + expectError: true, + }, + { + name: "IPv6 with port should be rejected", + url: "[2001:db8::1]:443", + expectError: true, + }, + { + name: "Localhost IPv4 should be rejected", + url: "127.0.0.1:8080", + expectError: true, + }, + { + name: "Localhost IPv6 should be rejected", + url: "[::1]:443", + expectError: true, + }, + { + name: "REL URL with IPv4 should be rejected", + url: "rel://192.168.1.1:443", + expectError: true, + }, + { + name: "RELS URL with IPv4 should be rejected", + url: "rels://10.0.0.1:443", + expectError: true, + }, + { + name: "Empty URL", + url: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ExtractValidDomain(tt.url) + + if tt.expectError { + assert.Error(t, err, "Expected error for URL: %s", tt.url) + } else { + assert.NoError(t, err, "Unexpected error for URL: %s", tt.url) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url) + } + }) + } +} + +func TestExtractDomainFromHost(t *testing.T) { + tests := []struct { + name string + host string + expected string + expectError bool + }{ + { + name: "Valid domain", + host: "example.com", + expected: "example.com", + }, + { + name: "Subdomain", + host: "api.example.com", + expected: "api.example.com", + }, + { + name: "IPv4 address", + host: "192.168.1.1", + expectError: true, + }, + { + name: "IPv6 address", + host: "2001:db8::1", + expectError: true, + }, + { + name: "Empty host", + host: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := extractDomainFromHost(tt.host) + + if tt.expectError { + assert.Error(t, err, "Expected error for host: %s", tt.host) + } else { + assert.NoError(t, err, "Unexpected error for host: %s", tt.host) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host) + } + }) + } +} diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 439bcbb3c..2e54bffd9 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,11 +11,12 @@ import ( ) const ( - PriorityLocal = 100 - PriorityDNSRoute = 75 - PriorityUpstream = 50 - PriorityDefault = 1 - PriorityFallback = -100 + PriorityMgmtCache = 150 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 + PriorityFallback = -100 ) type SubdomainMatcher interface { @@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants to continue, try next handler if chainWriter.shouldContinue { - log.Tracef("handler requested continue to next handler for domain=%s", qname) + // Only log continue for non-management cache handlers to reduce noise + if entry.Priority != PriorityMgmtCache { + log.Tracef("handler requested continue to next handler for domain=%s", qname) + } continue } return diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index b776fbbe3..bac7875ec 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool { // String returns a string representation of the local resolver func (d *Resolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.records)) + return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) } func (d *Resolver) Stop() {} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go new file mode 100644 index 000000000..290395473 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt.go @@ -0,0 +1,360 @@ +package mgmt + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "sync" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/shared/management/domain" +) + +const dnsTimeout = 5 * time.Second + +// Resolver caches critical NetBird infrastructure domains +type Resolver struct { + records map[dns.Question][]dns.RR + mgmtDomain *domain.Domain + serverDomains *dnsconfig.ServerDomains + mutex sync.RWMutex +} + +// NewResolver creates a new management domains cache resolver. +func NewResolver() *Resolver { + return &Resolver{ + records: make(map[dns.Question][]dns.RR), + } +} + +// String returns a string representation of the resolver. +func (m *Resolver) String() string { + return "MgmtCacheResolver" +} + +// ServeDNS implements dns.Handler interface. +func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + m.continueToNext(w, r) + return + } + + question := r.Question[0] + question.Name = strings.ToLower(dns.Fqdn(question.Name)) + + if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { + m.continueToNext(w, r) + return + } + + m.mutex.RLock() + records, found := m.records[question] + m.mutex.RUnlock() + + if !found { + m.continueToNext(w, r) + return + } + + resp := &dns.Msg{} + resp.SetReply(r) + resp.Authoritative = false + resp.RecursionAvailable = true + + resp.Answer = append(resp.Answer, records...) + + log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write response: %v", err) + } +} + +// MatchSubdomains returns false since this resolver only handles exact domain matches +// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. +func (m *Resolver) MatchSubdomains() bool { + return false +} + +// continueToNext signals the handler chain to continue to the next handler. +func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write continue signal: %v", err) + } +} + +// AddDomain manually adds a domain to cache by resolving it. +func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + ctx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + if err != nil { + return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) + } + + var aRecords, aaaaRecords []dns.RR + for _, ip := range ips { + if ip.Is4() { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: ip.AsSlice(), + } + aRecords = append(aRecords, rr) + } else if ip.Is6() { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: ip.AsSlice(), + } + aaaaRecords = append(aaaaRecords, rr) + } + } + + m.mutex.Lock() + + if len(aRecords) > 0 { + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + m.records[aQuestion] = aRecords + } + + if len(aaaaRecords) > 0 { + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + m.records[aaaaQuestion] = aaaaRecords + } + + m.mutex.Unlock() + + log.Debugf("added domain=%s with %d A records and %d AAAA records", + d.SafeString(), len(aRecords), len(aaaaRecords)) + + return nil +} + +// PopulateFromConfig extracts and caches domains from the client configuration. +func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { + if mgmtURL == nil { + return nil + } + + d, err := dnsconfig.ExtractValidDomain(mgmtURL.String()) + if err != nil { + return fmt.Errorf("extract domain from URL: %w", err) + } + + m.mutex.Lock() + m.mgmtDomain = &d + m.mutex.Unlock() + + if err := m.AddDomain(ctx, d); err != nil { + return fmt.Errorf("add domain: %w", err) + } + + return nil +} + +// RemoveDomain removes a domain from the cache. +func (m *Resolver) RemoveDomain(d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + m.mutex.Lock() + defer m.mutex.Unlock() + + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + delete(m.records, aQuestion) + + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + delete(m.records, aaaaQuestion) + + log.Debugf("removed domain=%s from cache", d.SafeString()) + return nil +} + +// GetCachedDomains returns a list of all cached domains. +func (m *Resolver) GetCachedDomains() domain.List { + m.mutex.RLock() + defer m.mutex.RUnlock() + + domainSet := make(map[domain.Domain]struct{}) + for question := range m.records { + domainName := strings.TrimSuffix(question.Name, ".") + domainSet[domain.Domain(domainName)] = struct{}{} + } + + domains := make(domain.List, 0, len(domainSet)) + for d := range domainSet { + domains = append(domains, d) + } + + return domains +} + +// UpdateFromServerDomains updates the cache with server domains from network configuration. +// It merges new domains with existing ones, replacing entire domain types when updated. +// Empty updates are ignored to prevent clearing infrastructure domains during partial updates. +func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) { + newDomains := m.extractDomainsFromServerDomains(serverDomains) + var removedDomains domain.List + + if len(newDomains) > 0 { + m.mutex.Lock() + if m.serverDomains == nil { + m.serverDomains = &dnsconfig.ServerDomains{} + } + updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains) + m.serverDomains = &updatedServerDomains + m.mutex.Unlock() + + allDomains := m.extractDomainsFromServerDomains(updatedServerDomains) + currentDomains := m.GetCachedDomains() + removedDomains = m.removeStaleDomains(currentDomains, allDomains) + } + + m.addNewDomains(ctx, newDomains) + + return removedDomains, nil +} + +// removeStaleDomains removes cached domains not present in the target domain list. +// Management domains are preserved and never removed during server domain updates. +func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List { + var removedDomains domain.List + + for _, currentDomain := range currentDomains { + if m.isDomainInList(currentDomain, newDomains) { + continue + } + + if m.isManagementDomain(currentDomain) { + continue + } + + removedDomains = append(removedDomains, currentDomain) + if err := m.RemoveDomain(currentDomain); err != nil { + log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) + } + } + + return removedDomains +} + +// mergeServerDomains merges new server domains with existing ones. +// When a domain type is provided in the new domains, it completely replaces that type. +func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains { + merged := existing + + if incoming.Signal != "" { + merged.Signal = incoming.Signal + } + if len(incoming.Relay) > 0 { + merged.Relay = incoming.Relay + } + if incoming.Flow != "" { + merged.Flow = incoming.Flow + } + if len(incoming.Stuns) > 0 { + merged.Stuns = incoming.Stuns + } + if len(incoming.Turns) > 0 { + merged.Turns = incoming.Turns + } + + return merged +} + +// isDomainInList checks if domain exists in the list +func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool { + for _, d := range list { + if domain.SafeString() == d.SafeString() { + return true + } + } + return false +} + +// isManagementDomain checks if domain is the protected management domain +func (m *Resolver) isManagementDomain(domain domain.Domain) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return m.mgmtDomain != nil && domain == *m.mgmtDomain +} + +// addNewDomains resolves and caches all domains from the update +func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) { + for _, newDomain := range newDomains { + if err := m.AddDomain(ctx, newDomain); err != nil { + log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) + } else { + log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) + } + } +} + +func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List { + var domains domain.List + + if serverDomains.Signal != "" { + domains = append(domains, serverDomains.Signal) + } + + for _, relay := range serverDomains.Relay { + if relay != "" { + domains = append(domains, relay) + } + } + + if serverDomains.Flow != "" { + domains = append(domains, serverDomains.Flow) + } + + for _, stun := range serverDomains.Stuns { + if stun != "" { + domains = append(domains, stun) + } + } + + for _, turn := range serverDomains.Turns { + if turn != "" { + domains = append(domains, turn) + } + } + + return domains +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go new file mode 100644 index 000000000..99d289871 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -0,0 +1,416 @@ +package mgmt + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +func TestResolver_NewResolver(t *testing.T) { + resolver := NewResolver() + + assert.NotNil(t, resolver) + assert.NotNil(t, resolver.records) + assert.False(t, resolver.MatchSubdomains()) +} + +func TestResolver_ExtractDomainFromURL(t *testing.T) { + tests := []struct { + name string + urlStr string + expectedDom string + expectError bool + }{ + { + name: "HTTPS URL with port", + urlStr: "https://api.netbird.io:443", + expectedDom: "api.netbird.io", + expectError: false, + }, + { + name: "HTTP URL without port", + urlStr: "http://signal.example.com", + expectedDom: "signal.example.com", + expectError: false, + }, + { + name: "URL with path", + urlStr: "https://relay.netbird.io/status", + expectedDom: "relay.netbird.io", + expectError: false, + }, + { + name: "Invalid URL", + urlStr: "not-a-valid-url", + expectedDom: "not-a-valid-url", + expectError: false, + }, + { + name: "Empty URL", + urlStr: "", + expectedDom: "", + expectError: true, + }, + { + name: "STUN URL", + urlStr: "stun:stun.example.com:3478", + expectedDom: "stun.example.com", + expectError: false, + }, + { + name: "TURN URL", + urlStr: "turn:turn.example.com:3478", + expectedDom: "turn.example.com", + expectError: false, + }, + { + name: "REL URL", + urlStr: "rel://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + { + name: "RELS URL", + urlStr: "rels://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var parsedURL *url.URL + var err error + + if tt.urlStr != "" { + parsedURL, err = url.Parse(tt.urlStr) + if err != nil && !tt.expectError { + t.Fatalf("Failed to parse URL: %v", err) + } + } + + domain, err := extractDomainFromURL(parsedURL) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedDom, domain.SafeString()) + } + }) + } +} + +func TestResolver_PopulateFromConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := NewResolver() + + // Test with IP address - should return error since IP addresses are rejected + mgmtURL, _ := url.Parse("https://127.0.0.1") + + err := resolver.PopulateFromConfig(ctx, mgmtURL) + assert.Error(t, err) + assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed) + + // No domains should be cached when using IP addresses + domains := resolver.GetCachedDomains() + assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") +} + +func TestResolver_ServeDNS(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Add a test domain to the cache - use example.org which is reserved for testing + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Test A record query for cached domain + t.Run("Cached domain A record", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer") + }) + + // Test uncached domain signals to continue to next handler + t.Run("Uncached domain signals continue to next handler", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("unknown.example.com.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + // Zero flag set to true signals the handler chain to continue to next handler + assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain") + }) + + // Test that subdomains of cached domains are NOT resolved + t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query for a subdomain of our cached domain + req := new(dns.Msg) + req.SetQuestion("sub.example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains") + }) + + // Test case-insensitive matching + t.Run("Case-insensitive domain matching", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query with different casing + req := new(dns.Msg) + req.SetQuestion("EXAMPLE.ORG.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case") + }) +} + +func TestResolver_GetCachedDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + cachedDomains := resolver.GetCachedDomains() + + assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain") + assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original") + assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot") +} + +func TestResolver_ManagementDomainProtection(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + mgmtURL, _ := url.Parse("https://example.org") + err := resolver.PopulateFromConfig(ctx, mgmtURL) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + initialDomains := resolver.GetCachedDomains() + if len(initialDomains) == 0 { + t.Skip("Management domain failed to resolve, skipping test") + } + assert.Equal(t, 1, len(initialDomains), "Should have management domain cached") + assert.Equal(t, "example.org", initialDomains[0].SafeString()) + + serverDomains := dnsconfig.ServerDomains{ + Signal: "google.com", + Relay: []domain.Domain{"cloudflare.com"}, + } + + _, err = resolver.UpdateFromServerDomains(ctx, serverDomains) + if err != nil { + t.Logf("Server domains update failed: %v", err) + } + + finalDomains := resolver.GetCachedDomains() + + managementStillCached := false + for _, d := range finalDomains { + if d.SafeString() == "example.org" { + managementStillCached = true + break + } + } + assert.True(t, managementStillCached, "Management domain should never be removed") +} + +// extractDomainFromURL extracts a domain from a URL - test helper function +func extractDomainFromURL(u *url.URL) (domain.Domain, error) { + if u == nil { + return "", fmt.Errorf("URL is nil") + } + return dnsconfig.ExtractValidDomain(u.String()) +} + +func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Verify domains were added + cachedDomains := resolver.GetCachedDomains() + assert.Len(t, cachedDomains, 3) + + // Update with empty ServerDomains (simulating partial network map update) + emptyDomains := dnsconfig.ServerDomains{} + removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains) + assert.NoError(t, err) + + // Verify no domains were removed + assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty") + + // Verify all original domains are still cached + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "All original domains should still be cached") +} + +func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn) + partialDomains := dnsconfig.ServerDomains{ + Signal: "github.com", + } + removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Should remove only the old signal domain + assert.Len(t, removedDomains, 1, "Should remove only the old signal domain") + assert.Equal(t, "example.org", removedDomains[0].SafeString()) + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "github.com") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.NotContains(t, domainStrings, "example.org") +} + +func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only flow domain - new type, should preserve all existing) + partialDomains := dnsconfig.ServerDomains{ + Flow: "github.com", + } + removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "example.org") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.Contains(t, domainStrings, "github.com") +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index d160fa99a..0f89b9016 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,20 +3,23 @@ package dns import ( "fmt" "net/netip" + "net/url" "github.com/miekg/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/shared/management/domain" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func(domain.List, dns.Handler, int) - DeregisterHandlerFunc func(domain.List, int) + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func(domain.List, dns.Handler, int) + DeregisterHandlerFunc func(domain.List, int) + UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error } func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string { // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface func (m *MockServer) ProbeAvailability() { } + +func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + if m.UpdateServerConfigFunc != nil { + return m.UpdateServerConfigFunc(domains) + } + return nil +} + +func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { + return nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index cbcf6a256..8cb886203 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/netip" + "net/url" "runtime" "strings" "sync" @@ -15,7 +16,9 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -45,6 +48,8 @@ type Server interface { OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string ProbeAvailability() + UpdateServerConfig(domains dnsconfig.ServerDomains) error + PopulateManagementDomain(mgmtURL *url.URL) error } type nsGroupsByDomain struct { @@ -77,6 +82,8 @@ type DefaultServer struct { handlerChain *HandlerChain extraDomains map[domain.Domain]int + mgmtCacheResolver *mgmt.Resolver + // permanent related properties permanent bool hostsDNSHolder *hostsDNSHolder @@ -104,18 +111,20 @@ type handlerWrapper struct { type registeredHandlerMap map[types.HandlerID]handlerWrapper +// DefaultServerConfig holds configuration parameters for NewDefaultServer +type DefaultServerConfig struct { + WgInterface WGIface + CustomAddress string + StatusRecorder *peer.Status + StateManager *statemanager.Manager + DisableSys bool +} + // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, - stateManager *statemanager.Manager, - disableSys bool, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) { var addrPort *netip.AddrPort - if customAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(customAddress) + if config.CustomAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) if err != nil { return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) } @@ -123,13 +132,14 @@ func NewDefaultServer( } var dnsService service - if wgInterface.IsUserspaceBind() { - dnsService = NewServiceViaMemory(wgInterface) + if config.WgInterface.IsUserspaceBind() { + dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(wgInterface, addrPort) + dnsService = newServiceViaListener(config.WgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil + server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) + return server, nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems @@ -178,20 +188,24 @@ func newDefaultServer( ) *DefaultServer { handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) + + mgmtCacheResolver := mgmt.NewResolver() + defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: local.NewResolver(), - wgInterface: wgInterface, - statusRecorder: statusRecorder, - stateManager: stateManager, - hostsDNSHolder: newHostsDNSHolder(), - hostManager: &noopHostConfigurator{}, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), + wgInterface: wgInterface, + statusRecorder: statusRecorder, + stateManager: stateManager, + hostsDNSHolder: newHostsDNSHolder(), + hostManager: &noopHostConfigurator{}, + mgmtCacheResolver: mgmtCacheResolver, } // register with root zone, handler chain takes care of the routing @@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { - log.Debugf("registering handler %s with priority %d", handler, priority) + log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains) for _, domain := range domains { if domain == "" { @@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { } func (s *DefaultServer) deregisterHandler(domains []string, priority int) { - log.Debugf("deregistering handler %v with priority %d", domains, priority) + log.Debugf("deregistering handler with priority %d for %v", priority, domains) for _, domain := range domains { if domain == "" { @@ -432,6 +446,29 @@ func (s *DefaultServer) ProbeAvailability() { wg.Wait() } +func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + s.mux.Lock() + defer s.mux.Unlock() + + if s.mgmtCacheResolver != nil { + removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) + if err != nil { + return fmt.Errorf("update management cache resolver: %w", err) + } + + if len(removedDomains) > 0 { + s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) + } + + newDomains := s.mgmtCacheResolver.GetCachedDomains() + if len(newDomains) > 0 { + s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) + } + } + + return nil +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver if update.ServiceEnable { @@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain { ), ) } + +// PopulateManagementDomain populates the DNS cache with management domain +func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { + if s.mgmtCacheResolver != nil { + return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) + } + return nil +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 068f001d8..11575d500 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatal(err) } @@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: &mocWGIface{}, + CustomAddress: testCase.addrPort, + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 071e3617a..c19e0acb5 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -33,9 +33,11 @@ func SetCurrentMTU(mtu uint16) { } const ( - UpstreamTimeout = 15 * time.Second + UpstreamTimeout = 4 * time.Second + // ClientTimeout is the timeout for the dns.Client. + // Set longer than UpstreamTimeout to ensure context timeout takes precedence + ClientTimeout = 5 * time.Second - failsTillDeact = int32(5) reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second ) @@ -58,9 +60,7 @@ type upstreamResolverBase struct { upstreamServers []netip.AddrPort domain string disabled bool - failsCount atomic.Int32 successCount atomic.Int32 - failsTillDeact int32 mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration @@ -79,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain: domain, upstreamTimeout: UpstreamTimeout, reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.upstreamServers) } // ID returns the unique handler ID @@ -116,58 +115,102 @@ func (u *upstreamResolverBase) Stop() { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { requestID := GenerateRequestID() logger := log.WithField("request_id", requestID) - var err error - defer func() { - u.checkUpstreamFails(err) - }() logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + u.prepareRequest(r) + + if u.ctx.Err() != nil { + logger.Tracef("%s has been stopped", u) + return + } + + if u.tryUpstreamServers(w, r, logger) { + return + } + + u.writeErrorResponse(w, r, logger) +} + +func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } +} - select { - case <-u.ctx.Done(): - logger.Tracef("%s has been stopped", u) - return - default: +func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { + timeout := u.upstreamTimeout + if len(u.upstreamServers) > 1 { + maxTotal := 5 * time.Second + minPerUpstream := 2 * time.Second + scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) + if scaledTimeout > minPerUpstream { + timeout = scaledTimeout + } else { + timeout = minPerUpstream + } } for _, upstream := range u.upstreamServers { - var rm *dns.Msg - var t time.Duration - - func() { - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - defer cancel() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) - }() - - if err != nil { - if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) - continue - } - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) - continue + if u.queryUpstream(w, r, upstream, timeout, logger) { + return true } + } + return false +} - if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - continue - } +func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool { + var rm *dns.Msg + var t time.Duration + var err error - u.successCount.Add(1) - logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + var startTime time.Time + func() { + ctx, cancel := context.WithTimeout(u.ctx, timeout) + defer cancel() + startTime = time.Now() + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) + }() - if err = w.WriteMsg(rm); err != nil { - logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) - } - // count the fails only if they happen sequentially - u.failsCount.Store(0) + if err != nil { + u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) + return false + } + + if rm == nil || !rm.Response { + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + return false + } + + return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) +} + +func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { + if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) return } - u.failsCount.Add(1) + + elapsed := time.Since(startTime) + timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) + if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { + timeoutMsg += " " + peerInfo + } + timeoutMsg += fmt.Sprintf(" - error: %v", err) + logger.Warnf(timeoutMsg) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { + u.successCount.Add(1) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain) + + if err := w.WriteMsg(rm); err != nil { + logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) + } + return true +} + +func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) @@ -177,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -// checkUpstreamFails counts fails and disables or enables upstream resolving -// -// If fails count is greater that failsTillDeact, upstream resolving -// will be disabled for reactivatePeriod, after that time period fails counter -// will be reset and upstream will be reactivated. -func (u *upstreamResolverBase) checkUpstreamFails(err error) { - u.mutex.Lock() - defer u.mutex.Unlock() - - if u.failsCount.Load() < u.failsTillDeact || u.disabled { - return - } - - select { - case <-u.ctx.Done(): - return - default: - } - - u.disable(err) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (fail count exceeded)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - // TODO add domain meta - ) -} - // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability() { @@ -224,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() { default: } - // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact - if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { + // avoid probe if upstreams could resolve at least one query + if u.successCount.Load() > 0 { return } @@ -312,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() { } log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) - u.failsCount.Store(0) u.successCount.Add(1) u.reactivate() u.disabled = false @@ -416,3 +423,80 @@ func GenerateRequestID() string { } return hex.EncodeToString(bytes) } + +// FormatPeerStatus formats peer connection status information for debugging DNS timeouts +func FormatPeerStatus(peerState *peer.State) string { + isConnected := peerState.ConnStatus == peer.StatusConnected + hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() && + time.Since(peerState.LastWireguardHandshake) < 3*time.Minute + + statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP) + + switch { + case !isConnected: + statusInfo += " DISCONNECTED" + case !hasRecentHandshake: + statusInfo += " NO_RECENT_HANDSHAKE" + default: + statusInfo += " connected" + } + + if !peerState.LastWireguardHandshake.IsZero() { + timeSinceHandshake := time.Since(peerState.LastWireguardHandshake) + statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second)) + } else { + statusInfo += " no_handshake" + } + + if peerState.Relayed { + statusInfo += " via_relay" + } + + if peerState.Latency > 0 { + statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency) + } + + return statusInfo +} + +// findPeerForIP finds which peer handles the given IP address +func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { + if statusRecorder == nil { + return nil + } + + fullStatus := statusRecorder.GetFullStatus() + var bestMatch *peer.State + var bestPrefixLen int + + for _, peerState := range fullStatus.Peers { + routes := peerState.GetRoutes() + for route := range routes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + continue + } + + if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen { + peerStateCopy := peerState + bestMatch = &peerStateCopy + bestPrefixLen = prefix.Bits() + } + } + } + + return bestMatch +} + +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ddbf84ae4..6b7dcc05e 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns } func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - upstreamExchangeClient := &dns.Client{} + upstreamExchangeClient := &dns.Client{ + Timeout: ClientTimeout, + } return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } @@ -72,10 +74,11 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri } upstreamExchangeClient := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: timeout, } - return upstreamExchangeClient.Exchange(r, upstream) + return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 317588a27..434e5880b 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -34,7 +34,10 @@ func newUpstreamResolver( } func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) + client := &dns.Client{ + Timeout: ClientTimeout, + } + return ExchangeWithFallback(ctx, client, r, upstream) } func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 96b8bbb0f..eadcdd117 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -47,7 +47,9 @@ func newUpstreamResolver( } func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - client := &dns.Client{} + client := &dns.Client{ + Timeout: ClientTimeout, + } upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) @@ -110,7 +112,8 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura }, } client := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: dialTimeout, } return client, nil } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 51d870e2a..e1573e75e 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) } func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { + mockClient := &mockUpstreamResolver{ + err: dns.ErrTime, + r: new(dns.Msg), + rtt: time.Millisecond, + } + resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: &mockUpstreamResolver{ - err: nil, - r: new(dns.Msg), - rtt: time.Millisecond, - }, + ctx: context.TODO(), + upstreamClient: mockClient, upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, + reactivatePeriod: time.Microsecond * 100, } addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - resolver.failsTillDeact = 0 - resolver.reactivatePeriod = time.Microsecond * 100 - - responseWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { return nil }, - } failed := false resolver.deactivate = func(error) { failed = true + // After deactivation, make the mock client work again + mockClient.err = nil } reactivated := false @@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) + resolver.ProbeAvailability() if !failed { t.Errorf("expected that resolving was deactivated") @@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { return } - if resolver.failsCount.Load() != 0 { - t.Errorf("fails count after reactivation should be 0") - return - } - if resolver.disabled { t.Errorf("should be enabled") } diff --git a/client/internal/engine.go b/client/internal/engine.go index c1c67207b..c610de41e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "net/netip" + "net/url" "os" "reflect" "runtime" @@ -33,6 +34,7 @@ import ( nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" @@ -345,7 +347,7 @@ func (e *Engine) Stop() error { // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. // However, they will be established once an event with a list of peers to connect to will be received from Management Service -func (e *Engine) Start() error { +func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -401,6 +403,11 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer + // Populate DNS cache with NetbirdConfig and management URL for early resolution + if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache: %v", err) + } + e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ Context: e.ctx, PublicKey: e.config.WgPrivateKey.PublicKey().String(), @@ -661,6 +668,30 @@ func (e *Engine) removePeer(peerKey string) error { return nil } +// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response +func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { + if e.dnsServer == nil { + return nil + } + + // Populate management URL if provided + if mgmtURL != nil { + if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache with management URL: %v", err) + } + } + + // Populate NetbirdConfig domains if provided + if netbirdConfig != nil { + serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig) + if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { + return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err) + } + } + + return nil +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -692,6 +723,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("handle the flow configuration: %w", err) } + if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil { + log.Warnf("Failed to update DNS server config: %v", err) + } + // todo update signal } @@ -1557,7 +1592,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) + + dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ + WgInterface: e.wgInterface, + CustomAddress: e.config.CustomDNSAddress, + StatusRecorder: e.statusRecorder, + StateManager: e.stateManager, + DisableSys: e.config.DisableDNS, + }) if err != nil { return nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9d6361bc0..fc58dbdba 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -266,7 +266,7 @@ func TestEngine_SSH(t *testing.T) { }, }, nil } - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) } @@ -612,7 +612,7 @@ func TestEngine_Sync(t *testing.T) { } }() - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) return @@ -1069,7 +1069,7 @@ func TestEngine_MultiplePeers(t *testing.T) { defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) wg.Done() diff --git a/client/internal/login.go b/client/internal/login.go index d5412a110..257e3c3ac 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, return false, err } - _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) + _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) if isLoginNeeded(err) { return true, nil } @@ -69,14 +69,18 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string, return err } - serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) + serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) if serverKey != nil && isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) + if err != nil { + return err + } + } else if err != nil { return err } - return err + return nil } func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { @@ -101,11 +105,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm return mgmClient, err } -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { serverKey, err := mgmClient.GetServerPublicKey() if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err + return nil, nil, err } sysInfo := system.GetInfo(ctx) @@ -121,8 +125,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockInbound, config.LazyConnectionEnabled, ) - _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, err + loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) + return serverKey, loginResp, err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index ba27df654..9069cdcc5 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -2,11 +2,13 @@ package dnsinterceptor import ( "context" + "errors" "fmt" "net/netip" "runtime" "strings" "sync" + "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -26,6 +28,8 @@ import ( "github.com/netbirdio/netbird/route" ) +const dnsTimeout = 8 * time.Second + type domainMap map[domain.Domain][]netip.Prefix type internalDNATer interface { @@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) + client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) if err != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return @@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + startTime := time.Now() + reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) if err != nil { - logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + if errors.Is(err, context.DeadlineExceeded) { + elapsed := time.Since(startTime) + peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) + logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", + elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err) + } else { + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + } if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { logger.Errorf("failed writing DNS response: %v", err) } @@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR } return } + +func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { + if d.statusRecorder == nil { + return "" + } + + peerState, err := d.statusRecorder.GetPeer(peerKey) + if err != nil { + return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err) + } + + return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState)) +}