diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index ab3e611e1..589f2001a 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/management/domain" ) // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform @@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind } } -func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { +func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) { log.Info("create tun interface") routesString := routesToString(routes) - searchDomainsToString := searchDomainsToString(searchDomains) + searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList()) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) if err != nil { diff --git a/client/iface/device_android.go b/client/iface/device_android.go index a1e246fc5..44874ab0d 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -8,10 +8,11 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/management/domain" ) type WGTunDevice interface { - Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) + Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error) Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go index 5e17c6d41..9b398653c 100644 --- a/client/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -2,7 +2,11 @@ package iface -import "fmt" +import ( + "fmt" + + "github.com/netbirdio/netbird/management/domain" +) // Create creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. @@ -21,6 +25,6 @@ func (w *WGIface) Create() error { } // CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { +func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error { return fmt.Errorf("this function has not implemented on non mobile") } diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go index 373a9c95a..1dd8bbf9b 100644 --- a/client/iface/iface_create_android.go +++ b/client/iface/iface_create_android.go @@ -2,11 +2,13 @@ package iface import ( "fmt" + + "github.com/netbirdio/netbird/management/domain" ) // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. -func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { +func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/iface/iface_create_darwin.go b/client/iface/iface_create_darwin.go index 1d91bce54..2a6424ed9 100644 --- a/client/iface/iface_create_darwin.go +++ b/client/iface/iface_create_darwin.go @@ -7,6 +7,8 @@ import ( "time" "github.com/cenkalti/backoff/v4" + + "github.com/netbirdio/netbird/management/domain" ) // Create creates a new Wireguard interface, sets a given IP and brings it up. @@ -36,6 +38,6 @@ func (w *WGIface) Create() error { } // CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { +func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error { return fmt.Errorf("this function has not implemented on this platform") } diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 3e338267f..234f4e51d 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string { continue } - listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, ".")) + listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), ".")) } return listOfDomains } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 6baf9ed95..21f1908b0 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -7,6 +7,8 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/domain" ) const ( @@ -23,8 +25,8 @@ type SubdomainMatcher interface { type HandlerEntry struct { Handler dns.Handler Priority int - Pattern string - OrigPattern string + Pattern domain.Domain + OrigPattern domain.Domain IsWildcard bool MatchSubdomains bool } @@ -38,7 +40,7 @@ type HandlerChain struct { // ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain type ResponseWriterChain struct { dns.ResponseWriter - origPattern string + origPattern domain.Domain shouldContinue bool } @@ -58,18 +60,18 @@ func NewHandlerChain() *HandlerChain { } // GetOrigPattern returns the original pattern of the handler that wrote the response -func (w *ResponseWriterChain) GetOrigPattern() string { +func (w *ResponseWriterChain) GetOrigPattern() domain.Domain { return w.origPattern } // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority -func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) { +func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) { c.mu.Lock() defer c.mu.Unlock() - pattern = strings.ToLower(dns.Fqdn(pattern)) + pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString()))) origPattern := pattern - isWildcard := strings.HasPrefix(pattern, "*.") + isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.") if isWildcard { pattern = pattern[2:] } @@ -109,8 +111,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int { // domain specificity next if h.Priority == newEntry.Priority { - newDots := strings.Count(newEntry.Pattern, ".") - existingDots := strings.Count(h.Pattern, ".") + newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".") + existingDots := strings.Count(h.Pattern.PunycodeString(), ".") if newDots > existingDots { return i } @@ -122,20 +124,20 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int { } // RemoveHandler removes a handler for the given pattern and priority -func (c *HandlerChain) RemoveHandler(pattern string, priority int) { +func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) { c.mu.Lock() defer c.mu.Unlock() - pattern = dns.Fqdn(pattern) + pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString())) c.removeEntry(pattern, priority) } -func (c *HandlerChain) removeEntry(pattern string, priority int) { +func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) { // Find and remove handlers matching both original pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] - if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { + if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority { c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) break } @@ -169,16 +171,16 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { case entry.Pattern == ".": matched = true case entry.IsWildcard: - parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") - matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) + parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".") + matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString()) default: // For non-wildcard patterns: // If handler wants subdomain matching, allow suffix match // Otherwise require exact match if entry.MatchSubdomains { - matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern) + matched = strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString()) } else { - matched = strings.EqualFold(qname, entry.Pattern) + matched = strings.EqualFold(qname, entry.Pattern.PunycodeString()) } } diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 5f03e0758..793af2c6f 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -9,6 +9,7 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/management/domain" ) // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order @@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { tests := []struct { name string - handlerDomain string - queryDomain string + handlerDomain domain.Domain + queryDomain domain.Domain isWildcard bool matchSubdomains bool shouldMatch bool @@ -141,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { chain.AddHandler(pattern, handler, nbdns.PriorityDefault) r := new(dns.Msg) - r.SetQuestion(tt.queryDomain, dns.TypeA) + r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA) w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) @@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { tests := []struct { name string handlers []struct { - pattern string + pattern domain.Domain priority int } - queryDomain string + queryDomain domain.Domain expectedCalls int expectedHandler int // index of the handler that should be called }{ { name: "wildcard and exact same priority - exact should win", handlers: []struct { - pattern string + pattern domain.Domain priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, @@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { { name: "higher priority wildcard over lower priority exact", handlers: []struct { - pattern string + pattern domain.Domain priority int }{ {pattern: "example.com.", priority: nbdns.PriorityDefault}, @@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { { name: "multiple wildcards different priorities", handlers: []struct { - pattern string + pattern domain.Domain priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, @@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { { name: "subdomain with mix of patterns", handlers: []struct { - pattern string + pattern domain.Domain priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, @@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { { name: "root zone with specific domain", handlers: []struct { - pattern string + pattern domain.Domain priority int }{ {pattern: ".", priority: nbdns.PriorityDefault}, @@ -258,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { // Create and execute request r := new(dns.Msg) - r.SetQuestion(tt.queryDomain, dns.TypeA) + r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA) w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} chain.ServeDNS(w, r) @@ -330,7 +331,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { name string ops []struct { action string // "add" or "remove" - pattern string + pattern domain.Domain priority int } query string @@ -340,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { name: "remove high priority keeps lower priority handler", ops: []struct { action string - pattern string + pattern domain.Domain priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, @@ -357,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { name: "remove lower priority keeps high priority handler", ops: []struct { action string - pattern string + pattern domain.Domain priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, @@ -374,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { name: "remove all handlers in order", ops: []struct { action string - pattern string + pattern domain.Domain priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, @@ -436,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { func TestHandlerChain_MultiPriorityHandling(t *testing.T) { chain := nbdns.NewHandlerChain() - testDomain := "example.com." + testDomain := domain.Domain("example.com.") testQuery := "test.example.com." // Create handlers with MatchSubdomains enabled @@ -518,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { name string scenario string addHandlers []struct { - pattern string + pattern domain.Domain priority int subdomains bool shouldMatch bool @@ -530,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { name: "case insensitive exact match", scenario: "handler registered lowercase, query uppercase", addHandlers: []struct { - pattern string + pattern domain.Domain priority int subdomains bool shouldMatch bool @@ -544,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { name: "case insensitive wildcard match", scenario: "handler registered mixed case wildcard, query different case", addHandlers: []struct { - pattern string + pattern domain.Domain priority int subdomains bool shouldMatch bool @@ -558,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { name: "multiple handlers different case same domain", scenario: "second handler should replace first despite case difference", addHandlers: []struct { - pattern string + pattern domain.Domain priority int subdomains bool shouldMatch bool @@ -573,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { name: "subdomain matching case insensitive", scenario: "handler with MatchSubdomains true should match regardless of case", addHandlers: []struct { - pattern string + pattern domain.Domain priority int subdomains bool shouldMatch bool @@ -587,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { name: "root zone case insensitive", scenario: "root zone handler should match regardless of case", addHandlers: []struct { - pattern string + pattern domain.Domain priority int subdomains bool shouldMatch bool @@ -601,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { name: "multiple handlers different priority", scenario: "should call higher priority handler despite case differences", addHandlers: []struct { - pattern string + pattern domain.Domain priority int subdomains bool shouldMatch bool @@ -618,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { chain := nbdns.NewHandlerChain() - handlerCalls := make(map[string]bool) // track which patterns were called + handlerCalls := make(map[domain.Domain]bool) // track which patterns were called // Add handlers according to test case for _, h := range tt.addHandlers { @@ -686,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { scenario string ops []struct { action string - pattern string + pattern domain.Domain priority int subdomain bool } - query string - expectedMatch string + query domain.Domain + expectedMatch domain.Domain }{ { name: "more specific domain matches first", scenario: "sub.example.com should match before example.com", ops: []struct { action string - pattern string + pattern domain.Domain priority int subdomain bool }{ @@ -713,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { scenario: "sub.example.com should match before example.com", ops: []struct { action string - pattern string + pattern domain.Domain priority int subdomain bool }{ @@ -728,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { scenario: "after removing most specific, should fall back to less specific", ops: []struct { action string - pattern string + pattern domain.Domain priority int subdomain bool }{ @@ -745,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { scenario: "less specific domain with higher priority should match first", ops: []struct { action string - pattern string + pattern domain.Domain priority int subdomain bool }{ @@ -760,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { scenario: "with equal priority, more specific domain should match", ops: []struct { action string - pattern string + pattern domain.Domain priority int subdomain bool }{ @@ -776,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { scenario: "specific domain should match before wildcard at same priority", ops: []struct { action string - pattern string + pattern domain.Domain priority int subdomain bool }{ @@ -791,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { chain := nbdns.NewHandlerChain() - handlers := make(map[string]*nbdns.MockSubdomainHandler) + handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler) for _, op := range tt.ops { if op.action == "add" { @@ -804,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { } r := new(dns.Msg) - r.SetQuestion(tt.query, dns.TypeA) + r.SetQuestion(tt.query.PunycodeString(), dns.TypeA) w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Setup handler expectations @@ -836,9 +837,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { tests := []struct { name string - addPattern string - removePattern string - queryPattern string + addPattern domain.Domain + removePattern domain.Domain + queryPattern domain.Domain shouldBeRemoved bool description string }{ @@ -954,7 +955,7 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { handler := &nbdns.MockHandler{} r := new(dns.Msg) - r.SetQuestion(tt.queryPattern, dns.TypeA) + r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA) w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // First verify no handler is called before adding any diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index dbf0f2cfc..a301fee9c 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" ) var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") @@ -39,9 +40,9 @@ type HostDNSConfig struct { } type DomainConfig struct { - Disabled bool `json:"disabled"` - Domain string `json:"domain"` - MatchOnly bool `json:"matchOnly"` + Disabled bool `json:"disabled"` + Domain domain.Domain `json:"domain"` + MatchOnly bool `json:"matchOnly"` } type mockHostConfigurator struct { @@ -103,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD config.RouteAll = true } - for _, domain := range nsConfig.Domains { + for _, d := range nsConfig.Domains { + d := strings.ToLower(dns.Fqdn(d.PunycodeString())) config.Domains = append(config.Domains, DomainConfig{ - Domain: strings.ToLower(dns.Fqdn(domain)), + Domain: domain.Domain(d), MatchOnly: !nsConfig.SearchDomainsEnabled, }) } } for _, customZone := range dnsConfig.CustomZones { - matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) + d := strings.ToLower(dns.Fqdn(customZone.Domain)) + matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone) config.Domains = append(config.Domains, DomainConfig{ - Domain: strings.ToLower(dns.Fqdn(customZone.Domain)), + Domain: domain.Domain(d), MatchOnly: matchOnly, }) } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index a445bc6c4..18fb42d71 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * continue } if dConf.MatchOnly { - matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, ".")) + matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), ".")) continue } - searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, ".")) + searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), ".")) } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index cfba29501..de418fae5 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -100,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager continue } if !dConf.MatchOnly { - searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, ".")) + searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), ".")) } - matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) + matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain.PunycodeString(), ".")) } if len(matchDomains) != 0 { diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index c5dd6e23f..381cd3625 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -62,8 +62,8 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { return fmt.Errorf("method UpdateDNSServer is not implemented") } -func (m *MockServer) SearchDomains() []string { - return make([]string, 0) +func (m *MockServer) SearchDomains() domain.List { + return make(domain.List, 0) } // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index caae63a24..5ea44c49a 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -125,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st continue } if dConf.MatchOnly { - matchDomains = append(matchDomains, "~."+dConf.Domain) + matchDomains = append(matchDomains, "~."+dConf.Domain.PunycodeString()) continue } - searchDomains = append(searchDomains, dConf.Domain) + searchDomains = append(searchDomains, dConf.Domain.PunycodeString()) } newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic diff --git a/client/internal/dns/notifier.go b/client/internal/dns/notifier.go index 35cb6ff82..8d8351cfe 100644 --- a/client/internal/dns/notifier.go +++ b/client/internal/dns/notifier.go @@ -1,21 +1,19 @@ package dns import ( - "reflect" - "sort" "sync" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/management/domain" ) type notifier struct { listener listener.NetworkChangeListener listenerMux sync.Mutex - searchDomains []string + searchDomains domain.List } -func newNotifier(initialSearchDomains []string) *notifier { - sort.Strings(initialSearchDomains) +func newNotifier(initialSearchDomains domain.List) *notifier { return ¬ifier{ searchDomains: initialSearchDomains, } @@ -27,16 +25,8 @@ func (n *notifier) setListener(listener listener.NetworkChangeListener) { n.listener = listener } -func (n *notifier) onNewSearchDomains(searchDomains []string) { - sort.Strings(searchDomains) - - if len(n.searchDomains) != len(searchDomains) { - n.searchDomains = searchDomains - n.notify() - return - } - - if reflect.DeepEqual(n.searchDomains, searchDomains) { +func (n *notifier) onNewSearchDomains(searchDomains domain.List) { + if searchDomains.Equal(n.searchDomains) { return } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 3f49c23fd..771b00519 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -44,12 +44,12 @@ type Server interface { DnsIP() string UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(strings []string) - SearchDomains() []string + SearchDomains() domain.List ProbeAvailability() } type nsGroupsByDomain struct { - domain string + domain domain.Domain groups []*nbdns.NameServerGroup } @@ -90,7 +90,7 @@ type handlerWithStop interface { } type handlerWrapper struct { - domain string + domain domain.Domain handler handlerWithStop priority int } @@ -197,7 +197,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler s.mux.Lock() defer s.mux.Unlock() - s.registerHandler(domains.ToPunycodeList(), handler, priority) + s.registerHandler(domains, handler, priority) // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain for _, domain := range domains { @@ -207,7 +207,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler s.applyHostConfig() } -func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { +func (s *DefaultServer) registerHandler(domains domain.List, handler dns.Handler, priority int) { log.Debugf("registering handler %s with priority %d", handler, priority) for _, domain := range domains { @@ -224,7 +224,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { s.mux.Lock() defer s.mux.Unlock() - s.deregisterHandler(domains.ToPunycodeList(), priority) + s.deregisterHandler(domains, priority) for _, domain := range domains { zone := toZone(domain) s.extraDomains[zone]-- @@ -235,7 +235,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { s.applyHostConfig() } -func (s *DefaultServer) deregisterHandler(domains []string, priority int) { +func (s *DefaultServer) deregisterHandler(domains domain.List, priority int) { log.Debugf("deregistering handler %v with priority %d", domains, priority) for _, domain := range domains { @@ -378,8 +378,8 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro return nil } -func (s *DefaultServer) SearchDomains() []string { - var searchDomains []string +func (s *DefaultServer) SearchDomains() domain.List { + var searchDomains domain.List for _, dConf := range s.currentConfig.Domains { if dConf.Disabled { @@ -472,18 +472,16 @@ func (s *DefaultServer) applyHostConfig() { config := s.currentConfig - existingDomains := make(map[string]struct{}) + existingDomains := make(map[domain.Domain]struct{}) for _, d := range config.Domains { existingDomains[d.Domain] = struct{}{} } // add extra domains only if they're not already in the config - for domain := range s.extraDomains { - domainStr := domain.PunycodeString() - - if _, exists := existingDomains[domainStr]; !exists { + for d := range s.extraDomains { + if _, exists := existingDomains[d]; !exists { config.Domains = append(config.Domains, DomainConfig{ - Domain: domainStr, + Domain: d, MatchOnly: true, }) } @@ -525,7 +523,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) } muxUpdates = append(muxUpdates, handlerWrapper{ - domain: customZone.Domain, + domain: domain.Domain(customZone.Domain), handler: s.localResolver, priority: PriorityMatchDomain, }) @@ -647,7 +645,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { // this will introduce a short period of time when the server is not able to handle DNS requests for _, existing := range s.dnsMuxMap { - s.deregisterHandler([]string{existing.domain}, existing.priority) + s.deregisterHandler(domain.List{existing.domain}, existing.priority) existing.handler.Stop() } @@ -658,7 +656,7 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { if update.domain == nbdns.RootZone { containsRootUpdate = true } - s.registerHandler([]string{update.domain}, update.handler, update.priority) + s.registerHandler(domain.List{update.domain}, update.handler, update.priority) muxUpdateMap[update.handler.ID()] = update } @@ -687,7 +685,7 @@ func (s *DefaultServer) upstreamCallbacks( handler dns.Handler, priority int, ) (deactivate func(error), reactivate func()) { - var removeIndex map[string]int + var removeIndex map[domain.Domain]int deactivate = func(err error) { s.mux.Lock() defer s.mux.Unlock() @@ -695,20 +693,20 @@ func (s *DefaultServer) upstreamCallbacks( l := log.WithField("nameservers", nsGroup.NameServers) l.Info("Temporarily deactivating nameservers group due to timeout") - removeIndex = make(map[string]int) + removeIndex = make(map[domain.Domain]int) for _, domain := range nsGroup.Domains { removeIndex[domain] = -1 } if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.deregisterHandler([]string{nbdns.RootZone}, priority) + s.deregisterHandler(domain.List{nbdns.RootZone}, priority) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.deregisterHandler([]string{item.Domain}, priority) + s.deregisterHandler(domain.List{item.Domain}, priority) removeIndex[item.Domain] = i } } @@ -732,12 +730,12 @@ func (s *DefaultServer) upstreamCallbacks( s.mux.Lock() defer s.mux.Unlock() - for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { + for d, i := range removeIndex { + if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != d{ continue } s.currentConfig.Domains[i].Disabled = false - s.registerHandler([]string{domain}, handler, priority) + s.registerHandler(domain.List{d}, handler, priority) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -745,7 +743,7 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.registerHandler([]string{nbdns.RootZone}, handler, priority) + s.registerHandler(domain.List{nbdns.RootZone}, handler, priority) } s.applyHostConfig() @@ -777,7 +775,7 @@ func (s *DefaultServer) addHostRootZone() { handler.deactivate = func(error) {} handler.reactivate = func() {} - s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) + s.registerHandler(domain.List{nbdns.RootZone}, handler, PriorityDefault) } func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { @@ -792,7 +790,7 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { state := peer.NSGroupState{ ID: generateGroupKey(group), Servers: servers, - Domains: group.Domains, + Domains: group.Domains.ToPunycodeList(), // The probe will determine the state, default enabled Enabled: true, Error: nil, @@ -825,7 +823,7 @@ func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { // groupNSGroupsByDomain groups nameserver groups by their match domains func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain { - domainMap := make(map[string][]*nbdns.NameServerGroup) + domainMap := make(map[domain.Domain][]*nbdns.NameServerGroup) for _, group := range nsGroups { if group.Primary { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 1c7c9b117..6364ed072 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -6,7 +6,6 @@ import ( "net" "net/netip" "os" - "strings" "testing" "time" @@ -97,7 +96,7 @@ func init() { formatter.SetTextFormatter(log.StandardLogger()) } -func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { +func generateDummyHandler(domain domain.Domain, servers []nbdns.NameServer) *upstreamResolverBase { var srvs []string for _, srv := range servers { srvs = append(srvs, getNSHostPort(srv)) @@ -152,7 +151,7 @@ func TestUpdateDNSServer(t *testing.T) { }, NameServerGroups: []*nbdns.NameServerGroup{ { - Domains: []string{"netbird.io"}, + Domains: domain.List{"netbird.io"}, NameServers: nameServers, }, { @@ -184,7 +183,7 @@ func TestUpdateDNSServer(t *testing.T) { name: "New Config Should Succeed", initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ + generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, priority: PriorityMatchDomain, @@ -202,7 +201,7 @@ func TestUpdateDNSServer(t *testing.T) { }, NameServerGroups: []*nbdns.NameServerGroup{ { - Domains: []string{"netbird.io"}, + Domains: domain.List{"netbird.io"}, NameServers: nameServers, }, }, @@ -303,8 +302,8 @@ func TestUpdateDNSServer(t *testing.T) { name: "Empty Config Should Succeed and Clean Maps", initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ - domain: zoneRecords[0].Name, + generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{ + domain: domain.Domain(zoneRecords[0].Name), handler: dummyHandler, priority: PriorityMatchDomain, }, @@ -319,8 +318,8 @@ func TestUpdateDNSServer(t *testing.T) { name: "Disabled Service Should clean map", initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, initUpstreamMap: registeredHandlerMap{ - generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ - domain: zoneRecords[0].Name, + generateDummyHandler(domain.Domain(zoneRecords[0].Name), nameServers).ID(): handlerWrapper{ + domain: domain.Domain(zoneRecords[0].Name), handler: dummyHandler, priority: PriorityMatchDomain, }, @@ -501,7 +500,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { dnsServer.dnsMuxMap = registeredHandlerMap{ "id1": handlerWrapper{ - domain: zoneRecords[0].Name, + domain: domain.Domain(zoneRecords[0].Name), handler: &local.Resolver{}, priority: PriorityMatchDomain, }, @@ -533,7 +532,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { }, NameServerGroups: []*nbdns.NameServerGroup{ { - Domains: []string{"netbird.io"}, + Domains: domain.List{"netbird.io"}, NameServers: nameServers, }, { @@ -599,7 +598,7 @@ func TestDNSServerStartStop(t *testing.T) { t.Error(err) } - dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1) + dnsServer.registerHandler(domain.List{"netbird.cloud"}, dnsServer.localResolver, 1) resolver := &net.Resolver{ PreferGo: true, @@ -659,48 +658,48 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { var domainsUpdate string hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { - domains := []string{} + domains := domain.List{} for _, item := range config.Domains { if item.Disabled { continue } domains = append(domains, item.Domain) } - domainsUpdate = strings.Join(domains, ",") + domainsUpdate = domains.PunycodeString() return nil } deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ - Domains: []string{"domain1"}, + Domains: domain.List{"domain1"}, NameServers: []nbdns.NameServer{ {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, }, }, nil, 0) deactivate(nil) - expected := "domain0,domain2" - domains := []string{} + expected := "domain0, domain2" + domains := domain.List{} for _, item := range server.currentConfig.Domains { if item.Disabled { continue } domains = append(domains, item.Domain) } - got := strings.Join(domains, ",") + got := domains.PunycodeString() if expected != got { t.Errorf("expected domains list: %q, got %q", expected, got) } reactivate() - expected = "domain0,domain1,domain2" - domains = []string{} + expected = "domain0, domain1, domain2" + domains = domain.List{} for _, item := range server.currentConfig.Domains { if item.Disabled { continue } domains = append(domains, item.Domain) } - got = strings.Join(domains, ",") + got = domains.PunycodeString() if expected != got { t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) } @@ -868,7 +867,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { Port: 53, }, }, - Domains: []string{"google.com"}, + Domains: domain.List{"google.com"}, Primary: false, }, }, @@ -1123,7 +1122,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { name string initialHandlers registeredHandlerMap updates []handlerWrapper - expectedHandlers map[string]string // map[HandlerID]domain + expectedHandlers map[string]domain.Domain // map[HandlerID]domain description string }{ { @@ -1139,7 +1138,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityMatchDomain - 1, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-group2": "example.com", }, description: "When group1 is not included in the update, it should be removed while group2 remains", @@ -1157,7 +1156,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityMatchDomain, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-group1": "example.com", }, description: "When group2 is not included in the update, it should be removed while group1 remains", @@ -1190,7 +1189,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityMatchDomain - 1, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-group1": "example.com", "upstream-group2": "example.com", "upstream-group3": "example.com", @@ -1225,7 +1224,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityMatchDomain - 2, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-group1": "example.com", "upstream-group2": "example.com", "upstream-group3": "example.com", @@ -1245,7 +1244,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityDefault - 1, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-root2": ".", }, description: "When root1 is not included in the update, it should be removed while root2 remains", @@ -1262,7 +1261,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityDefault, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-root1": ".", }, description: "When root2 is not included in the update, it should be removed while root1 remains", @@ -1293,7 +1292,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityDefault - 1, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-root1": ".", "upstream-root2": ".", "upstream-root3": ".", @@ -1326,7 +1325,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityDefault - 2, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-root1": ".", "upstream-root2": ".", "upstream-root3": ".", @@ -1353,7 +1352,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityMatchDomain, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-group1": "example.com", "upstream-other": "other.com", }, @@ -1392,7 +1391,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { priority: PriorityMatchDomain, }, }, - expectedHandlers: map[string]string{ + expectedHandlers: map[string]domain.Domain{ "upstream-group1": "example.com", "upstream-group2": "example.com", "upstream-other": "other.com", @@ -1448,7 +1447,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { for _, muxEntry := range server.dnsMuxMap { if chainEntry.Handler == muxEntry.handler && chainEntry.Priority == muxEntry.priority && - chainEntry.Pattern == dns.Fqdn(muxEntry.domain) { + chainEntry.Pattern.PunycodeString() == dns.Fqdn(muxEntry.domain.PunycodeString()) { foundInMux = true break } @@ -1467,8 +1466,8 @@ func TestExtraDomains(t *testing.T) { registerDomains []domain.List deregisterDomains []domain.List finalConfig nbdns.Config - expectedDomains []string - expectedMatchOnly []string + expectedDomains domain.List + expectedMatchOnly domain.List applyHostConfigCall int }{ { @@ -1482,12 +1481,12 @@ func TestExtraDomains(t *testing.T) { {Domain: "config.example.com"}, }, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "config.example.com.", "extra1.example.com.", "extra2.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "extra1.example.com.", "extra2.example.com.", }, @@ -1504,12 +1503,12 @@ func TestExtraDomains(t *testing.T) { registerDomains: []domain.List{ {"extra1.example.com", "extra2.example.com"}, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "config.example.com.", "extra1.example.com.", "extra2.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "extra1.example.com.", "extra2.example.com.", }, @@ -1527,12 +1526,12 @@ func TestExtraDomains(t *testing.T) { registerDomains: []domain.List{ {"extra.example.com", "overlap.example.com"}, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "config.example.com.", "overlap.example.com.", "extra.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "extra.example.com.", }, applyHostConfigCall: 2, @@ -1552,12 +1551,12 @@ func TestExtraDomains(t *testing.T) { deregisterDomains: []domain.List{ {"extra1.example.com", "extra3.example.com"}, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "config.example.com.", "extra2.example.com.", "extra4.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "extra2.example.com.", "extra4.example.com.", }, @@ -1578,13 +1577,13 @@ func TestExtraDomains(t *testing.T) { deregisterDomains: []domain.List{ {"duplicate.example.com"}, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "config.example.com.", "extra.example.com.", "other.example.com.", "duplicate.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "extra.example.com.", "other.example.com.", "duplicate.example.com.", @@ -1609,13 +1608,13 @@ func TestExtraDomains(t *testing.T) { {Domain: "newconfig.example.com"}, }, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "config.example.com.", "newconfig.example.com.", "extra.example.com.", "duplicate.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "extra.example.com.", "duplicate.example.com.", }, @@ -1636,12 +1635,12 @@ func TestExtraDomains(t *testing.T) { deregisterDomains: []domain.List{ {"protected.example.com"}, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "extra.example.com.", "config.example.com.", "protected.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "extra.example.com.", }, applyHostConfigCall: 3, @@ -1652,7 +1651,7 @@ func TestExtraDomains(t *testing.T) { ServiceEnable: true, NameServerGroups: []*nbdns.NameServerGroup{ { - Domains: []string{"ns.example.com", "overlap.ns.example.com"}, + Domains: domain.List{"ns.example.com", "overlap.ns.example.com"}, NameServers: []nbdns.NameServer{ { IP: netip.MustParseAddr("8.8.8.8"), @@ -1666,12 +1665,12 @@ func TestExtraDomains(t *testing.T) { registerDomains: []domain.List{ {"extra.example.com", "overlap.ns.example.com"}, }, - expectedDomains: []string{ + expectedDomains: domain.List{ "ns.example.com.", "overlap.ns.example.com.", "extra.example.com.", }, - expectedMatchOnly: []string{ + expectedMatchOnly: domain.List{ "ns.example.com.", "overlap.ns.example.com.", "extra.example.com.", @@ -1742,8 +1741,8 @@ func TestExtraDomains(t *testing.T) { lastConfig := capturedConfigs[len(capturedConfigs)-1] // Check all expected domains are present - domainMap := make(map[string]bool) - matchOnlyMap := make(map[string]bool) + domainMap := make(map[domain.Domain]bool) + matchOnlyMap := make(map[domain.Domain]bool) for _, d := range lastConfig.Domains { domainMap[d.Domain] = true @@ -1860,12 +1859,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) { err := server.applyConfiguration(initialConfig) assert.NoError(t, err) - var domains []string + var domains domain.List for _, d := range capturedConfig.Domains { domains = append(domains, d.Domain) } - assert.Contains(t, domains, "config.example.com.") - assert.Contains(t, domains, "extra.example.com.") + assert.Contains(t, domains, domain.Domain("config.example.com.")) + assert.Contains(t, domains, domain.Domain("extra.example.com.")) // Now apply a new configuration with overlapping domain updatedConfig := nbdns.Config{ @@ -1879,7 +1878,7 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) { assert.NoError(t, err) // Verify both domains are in config, but no duplicates - domains = []string{} + domains = domain.List{} matchOnlyCount := 0 for _, d := range capturedConfig.Domains { domains = append(domains, d.Domain) @@ -1888,12 +1887,12 @@ func TestUpdateConfigWithExistingExtraDomains(t *testing.T) { } } - assert.Contains(t, domains, "config.example.com.") - assert.Contains(t, domains, "extra.example.com.") + assert.Contains(t, domains, domain.Domain("config.example.com.")) + assert.Contains(t, domains, domain.Domain("extra.example.com.")) assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates") // Extra domain should no longer be marked as match-only when in config - matchOnlyDomain := "" + var matchOnlyDomain domain.Domain for _, d := range capturedConfig.Domains { if d.Domain == "extra.example.com." && d.MatchOnly { matchOnlyDomain = d.Domain @@ -1946,10 +1945,10 @@ func TestDomainCaseHandling(t *testing.T) { err := server.applyConfiguration(config) assert.NoError(t, err) - var domains []string + var domains domain.List for _, d := range capturedConfig.Domains { domains = append(domains, d.Domain) } - assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") - assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") + assert.Contains(t, domains, domain.Domain("config.example.com."), "Mixed case domain should be normalized and pre.sent") + assert.Contains(t, domains, domain.Domain("mixed.example.com."), "Mixed case domain should be normalized and present") } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 9040ed787..6dded6b26 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -117,15 +117,15 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana continue } domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ - Domain: dConf.Domain, + Domain: dConf.Domain.PunycodeString(), MatchOnly: dConf.MatchOnly, }) if dConf.MatchOnly { - matchDomains = append(matchDomains, dConf.Domain) + matchDomains = append(matchDomains, dConf.Domain.PunycodeString()) continue } - searchDomains = append(searchDomains, dConf.Domain) + searchDomains = append(searchDomains, dConf.Domain.PunycodeString()) } if config.RouteAll { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 2fbfb3b91..e71728c83 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/domain" ) const ( @@ -48,7 +49,7 @@ type upstreamResolverBase struct { cancel context.CancelFunc upstreamClient upstreamClient upstreamServers []string - domain string + domain domain.Domain disabled bool failsCount atomic.Int32 successCount atomic.Int32 @@ -62,7 +63,7 @@ type upstreamResolverBase struct { statusRecorder *peer.Status } -func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { +func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain domain.Domain) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 06ffcba11..7e426c4f8 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -9,6 +9,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/management/domain" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -27,7 +28,7 @@ func newUpstreamResolver( _ *net.IPNet, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, - domain string, + domain domain.Domain, ) (*upstreamResolver, error) { upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) c := &upstreamResolver{ diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 9bb5feab0..5f45b696a 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -10,6 +10,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/management/domain" ) type upstreamResolver struct { @@ -23,7 +24,7 @@ func newUpstreamResolver( _ *net.IPNet, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + domain domain.Domain, ) (*upstreamResolver, error) { upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) nonIOS := &upstreamResolver{ diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index ca5b31132..3ec478306 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -14,6 +14,7 @@ import ( "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/management/domain" ) type upstreamResolverIOS struct { @@ -30,7 +31,7 @@ func newUpstreamResolver( net *net.IPNet, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + domain domain.Domain, ) (*upstreamResolverIOS, error) { upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) diff --git a/client/internal/engine.go b/client/internal/engine.go index 7c501e5aa..60e05afb6 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1159,7 +1159,7 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.C for _, nsGroup := range protoDNSConfig.GetNameServerGroups() { dnsNSGroup := &nbdns.NameServerGroup{ Primary: nsGroup.GetPrimary(), - Domains: nsGroup.GetDomains(), + Domains: domain.FromPunycodeList(nsGroup.GetDomains()), SearchDomainsEnabled: nsGroup.GetSearchDomainsEnabled(), } for _, ns := range nsGroup.GetNameServers() { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 422059bd8..055b97bdc 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -44,6 +44,7 @@ import ( "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" mgmt "github.com/netbirdio/netbird/management/client" + "github.com/netbirdio/netbird/management/domain" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" @@ -77,7 +78,7 @@ var ( type MockWGIface struct { CreateFunc func() error - CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error + CreateOnAndroidFunc func(routeRange []string, ip string, domains domain.List) error IsUserspaceBindFunc func() bool NameFunc func() string AddressFunc func() wgaddr.Address @@ -107,7 +108,7 @@ func (m *MockWGIface) Create() error { return m.CreateFunc() } -func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error { +func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains domain.List) error { return m.CreateOnAndroidFunc(routeRange, ip, domains) } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index e1761ff84..3e7ed86e6 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -14,11 +14,12 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/management/domain" ) type wgIfaceBase interface { Create() error - CreateOnAndroid(routeRange []string, ip string, domains []string) error + CreateOnAndroid(routeRange []string, ip string, domains domain.List) error IsUserspaceBind() bool Name() string Address() wgaddr.Address diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 6d51c88c0..f080ff7d2 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -229,15 +229,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { } if len(r.Answer) > 0 && len(r.Question) > 0 { - origPattern := "" + var origPattern domain.Domain if writer, ok := w.(*nbdns.ResponseWriterChain); ok { origPattern = writer.GetOrigPattern() } resolvedDomain := domain.Domain(strings.ToLower(r.Question[0].Name)) - // already punycode via RegisterHandler() - originalDomain := domain.Domain(origPattern) + originalDomain := origPattern if originalDomain == "" { originalDomain = resolvedDomain } diff --git a/management/domain/domain.go b/management/domain/domain.go index 97acec688..16b93cf67 100644 --- a/management/domain/domain.go +++ b/management/domain/domain.go @@ -30,7 +30,7 @@ func (d Domain) SafeString() string { } // PunycodeString returns the punycode representation of the Domain. -// This should only be used if a punycode domain is expected but only a string is supported. +// This should only be used if a punycode domain is expected but only a string is supported (e.g. an external library). func (d Domain) PunycodeString() string { return string(d) } diff --git a/management/domain/list.go b/management/domain/list.go index a988f4f70..8d93dd24f 100644 --- a/management/domain/list.go +++ b/management/domain/list.go @@ -1,7 +1,7 @@ package domain import ( - "sort" + "slices" "strings" ) @@ -41,6 +41,7 @@ func (d List) ToSafeStringList() []string { } // String converts List to a comma-separated string. +// This is useful for displaying domain names in a user-friendly format. func (d List) String() (string, error) { list, err := d.ToStringList() if err != nil { @@ -50,7 +51,8 @@ func (d List) String() (string, error) { } // SafeString converts List to a comma-separated non-punycode string. -// If a domain cannot be converted, the original string is used. +// This is useful for displaying domain names in a user-friendly format. +// If a domain cannot be converted, the original (punycode) string is used. func (d List) SafeString() string { str, err := d.String() if err != nil { @@ -64,28 +66,22 @@ func (d List) PunycodeString() string { return strings.Join(d.ToPunycodeList(), ", ") } +// Equal checks if two domain lists are equal without considering the order. func (d List) Equal(domains List) bool { if len(d) != len(domains) { return false } - sort.Slice(d, func(i, j int) bool { - return d[i] < d[j] - }) + d1 := slices.Clone(d) + d2 := slices.Clone(domains) - sort.Slice(domains, func(i, j int) bool { - return domains[i] < domains[j] - }) + slices.Sort(d1) + slices.Sort(d2) - for i, domain := range d { - if domain != domains[i] { - return false - } - } - return true + return slices.Equal(d1, d2) } -// FromStringList creates a DomainList from a slice of string. +// FromStringList creates a List from a slice of strings. func FromStringList(s []string) (List, error) { var dl List for _, domain := range s {