diff --git a/client/embed/embed.go b/client/embed/embed.go index e266aae28..e73f37e35 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -69,6 +69,8 @@ type Options struct { StatePath string // DisableClientRoutes disables the client routes DisableClientRoutes bool + // BlockInbound blocks all inbound connections from peers + BlockInbound bool } // validateCredentials checks that exactly one credential type is provided @@ -137,6 +139,7 @@ func New(opts Options) (*Client, error) { PreSharedKey: &opts.PreSharedKey, DisableServerRoutes: &t, DisableClientRoutes: &opts.DisableClientRoutes, + BlockInbound: &opts.BlockInbound, } if opts.ConfigPath != "" { config, err = profilemanager.UpdateOrCreateConfig(input) diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 5458519fa..1b1a8ce1c 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -8,8 +8,6 @@ import ( "net" "sync" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/hashicorp/go-multierror" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" @@ -26,16 +24,6 @@ const ( loopbackAddr = "127.0.0.1" ) -var ( - localHostNetIPv4 = net.ParseIP("127.0.0.1") - localHostNetIPv6 = net.ParseIP("::1") - - serializeOpts = gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } -) - // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int @@ -253,63 +241,3 @@ generatePort: } return p.lastUsedPort, nil } - -func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { - - var ipH gopacket.SerializableLayer - var networkLayer gopacket.NetworkLayer - var dstIP net.IP - var rawConn net.PacketConn - - if endpointAddr.IP.To4() != nil { - // IPv4 path - ipv4 := &layers.IPv4{ - DstIP: localHostNetIPv4, - SrcIP: endpointAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - } - ipH = ipv4 - networkLayer = ipv4 - dstIP = localHostNetIPv4 - rawConn = p.rawConnIPv4 - } else { - // IPv6 path - if p.rawConnIPv6 == nil { - return fmt.Errorf("IPv6 raw socket not available") - } - ipv6 := &layers.IPv6{ - DstIP: localHostNetIPv6, - SrcIP: endpointAddr.IP, - Version: 6, - HopLimit: 64, - NextHeader: layers.IPProtocolUDP, - } - ipH = ipv6 - networkLayer = ipv6 - dstIP = localHostNetIPv6 - rawConn = p.rawConnIPv6 - } - - udpH := &layers.UDP{ - SrcPort: layers.UDPPort(endpointAddr.Port), - DstPort: layers.UDPPort(p.localWGListenPort), - } - - if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { - return fmt.Errorf("set network layer for checksum: %w", err) - } - - layerBuffer := gopacket.NewSerializeBuffer() - payload := gopacket.Payload(data) - - if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil { - return fmt.Errorf("serialize layers: %w", err) - } - - if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil { - return fmt.Errorf("write to raw conn: %w", err) - } - return nil -} diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 5b98be7b4..6e80945c4 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -10,12 +10,89 @@ import ( "net" "sync" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) +var ( + errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available") + errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available") + + localHostNetIPv4 = net.ParseIP("127.0.0.1") + localHostNetIPv6 = net.ParseIP("::1") + + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } +) + +// PacketHeaders holds pre-created headers and buffers for efficient packet sending +type PacketHeaders struct { + ipH gopacket.SerializableLayer + udpH *layers.UDP + layerBuffer gopacket.SerializeBuffer + localHostAddr net.IP + isIPv4 bool +} + +func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) { + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + var localHostAddr net.IP + var isIPv4 bool + + // Check if source address is IPv4 or IPv6 + if endpoint.IP.To4() != nil { + // IPv4 path + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPv4, + SrcIP: endpoint.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + localHostAddr = localHostNetIPv4 + isIPv4 = true + } else { + // IPv6 path + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPv6, + SrcIP: endpoint.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 + localHostAddr = localHostNetIPv6 + isIPv4 = false + } + + udpH := &layers.UDP{ + SrcPort: layers.UDPPort(endpoint.Port), + DstPort: layers.UDPPort(localWGListenPort), + } + + if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { + return nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return &PacketHeaders{ + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + localHostAddr: localHostAddr, + isIPv4: isIPv4, + }, nil +} + // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { wgeBPFProxy *WGEBPFProxy @@ -24,8 +101,10 @@ type ProxyWrapper struct { ctx context.Context cancel context.CancelFunc - wgRelayedEndpointAddr *net.UDPAddr - wgEndpointCurrentUsedAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + headers *PacketHeaders + headerCurrentUsed *PacketHeaders + rawConn net.PacketConn paused bool pausedCond *sync.Cond @@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { closeListener: listener.NewCloseListener(), } } + func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } + + headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr) + if err != nil { + return fmt.Errorf("create packet sender: %w", err) + } + + // Check if required raw connection is available + if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil { + return errIPv6ConnNotAvailable + } + if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil { + return errIPv4ConnNotAvailable + } + p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) p.wgRelayedEndpointAddr = addr - return err + p.headers = headers + p.rawConn = p.selectRawConn(headers) + return nil } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { @@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() { p.pausedCond.L.Lock() p.paused = false - p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr + p.headerCurrentUsed = p.headers + p.rawConn = p.selectRawConn(p.headerCurrentUsed) if !p.isStarted { p.isStarted = true @@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { log.Errorf("failed to start package redirection, endpoint is nil") return } + + header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create packet headers: %s", err) + return + } + + // Check if required raw connection is available + if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil { + log.Error(errIPv6ConnNotAvailable) + return + } + if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil { + log.Error(errIPv4ConnNotAvailable) + return + } + p.pausedCond.L.Lock() p.paused = false - p.wgEndpointCurrentUsedAddr = endpoint + p.headerCurrentUsed = header + p.rawConn = p.selectRawConn(header) p.pausedCond.Signal() p.pausedCond.L.Unlock() @@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { p.pausedCond.Wait() } - err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + err = p.sendPkg(buf[:n], p.headerCurrentUsed) p.pausedCond.L.Unlock() if err != nil { @@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } return n, nil } + +func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error { + defer func() { + if err := header.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil { + return fmt.Errorf("serialize layers: %w", err) + } + + if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil { + return fmt.Errorf("write to raw conn: %w", err) + } + return nil +} + +func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn { + if header.isIPv4 { + return p.wgeBPFProxy.rawConnIPv4 + } + return p.wgeBPFProxy.rawConnIPv6 +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 72c0004d5..fa9525069 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { matchSubdomains: false, shouldMatch: false, }, + { + name: "single letter TLD exact match", + handlerDomain: "example.x.", + queryDomain: "example.x.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "single letter TLD subdomain match", + handlerDomain: "example.x.", + queryDomain: "sub.example.x.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, + }, + { + name: "single letter TLD wildcard match", + handlerDomain: "*.example.x.", + queryDomain: "sub.example.x.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "two letter domain labels", + handlerDomain: "a.b.", + queryDomain: "a.b.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "single character domain", + handlerDomain: "x.", + queryDomain: "x.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "single character domain with subdomain match", + handlerDomain: "x.", + queryDomain: "sub.x.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, + }, } for _, tt := range tests { diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 71badf0d4..af84c8a85 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -9,8 +9,10 @@ import ( "io" "net/netip" "os/exec" + "slices" "strconv" "strings" + "sync" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -38,6 +40,9 @@ const ( type systemConfigurator struct { createdKeys map[string]struct{} systemDNSSettings SystemDNSSettings + + mu sync.RWMutex + origNameservers []netip.Addr } func newHostManager() (*systemConfigurator, error) { @@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } var dnsSettings SystemDNSSettings + var serverAddresses []netip.Addr inSearchDomainsArray := false inServerAddressesArray := false @@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] - if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { - dnsSettings.ServerIP = ip.Unmap() - inServerAddressesArray = false // Stop reading after finding the first IPv4 address + if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() { + ip = ip.Unmap() + serverAddresses = append(serverAddresses, ip) + if !dnsSettings.ServerIP.IsValid() && ip.Is4() { + dnsSettings.ServerIP = ip + } } } } @@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { // default to 53 port dnsSettings.ServerPort = DefaultPort + s.mu.Lock() + s.origNameservers = serverAddresses + s.mu.Unlock() + return dnsSettings, nil } +func (s *systemConfigurator) getOriginalNameservers() []netip.Addr { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.origNameservers) +} + func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { err := s.addDNSState(key, domains, ip, port, true) if err != nil { diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go index c4efd17b0..28915de65 100644 --- a/client/internal/dns/host_darwin_test.go +++ b/client/internal/dns/host_darwin_test.go @@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error { _, err := cmd.CombinedOutput() return err } + +func TestGetOriginalNameservers(t *testing.T) { + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + origNameservers: []netip.Addr{ + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("1.1.1.1"), + }, + } + + servers := configurator.getOriginalNameservers() + assert.Len(t, servers, 2) + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0]) + assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1]) +} + +func TestGetOriginalNameserversFromSystem(t *testing.T) { + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + + servers := configurator.getOriginalNameservers() + + require.NotEmpty(t, servers, "expected at least one DNS server from system configuration") + + for _, server := range servers { + assert.True(t, server.IsValid(), "server address should be valid") + assert.False(t, server.IsUnspecified(), "server address should not be unspecified") + } + + t.Logf("found %d original nameservers: %v", len(servers), servers) +} + +func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) { + t.Helper() + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + cleanup := func() { + _ = sm.Stop(context.Background()) + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + } + + return configurator, sm, cleanup +} + +func TestOriginalNameserversNoTransition(t *testing.T) { + netbirdIP := netip.MustParseAddr("100.64.0.1") + + testCases := []struct { + name string + routeAll bool + }{ + {"routeall_false", false}, + {"routeall_true", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configurator, sm, cleanup := setupTestConfigurator(t) + defer cleanup() + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + initialServers := configurator.getOriginalNameservers() + t.Logf("Initial servers: %v", initialServers) + require.NotEmpty(t, initialServers) + + for _, srv := range initialServers { + require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP") + } + + config := HostDNSConfig{ + ServerIP: netbirdIP, + ServerPort: 53, + RouteAll: tc.routeAll, + Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, + } + + for i := 1; i <= 2; i++ { + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + servers := configurator.getOriginalNameservers() + t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers) + assert.Equal(t, initialServers, servers) + } + }) + } +} + +func TestOriginalNameserversRouteAllTransition(t *testing.T) { + netbirdIP := netip.MustParseAddr("100.64.0.1") + + testCases := []struct { + name string + initialRoute bool + }{ + {"start_with_routeall_false", false}, + {"start_with_routeall_true", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configurator, sm, cleanup := setupTestConfigurator(t) + defer cleanup() + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + initialServers := configurator.getOriginalNameservers() + t.Logf("Initial servers: %v", initialServers) + require.NotEmpty(t, initialServers) + + config := HostDNSConfig{ + ServerIP: netbirdIP, + ServerPort: 53, + RouteAll: tc.initialRoute, + Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, + } + + // First apply + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers := configurator.getOriginalNameservers() + t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers) + assert.Equal(t, initialServers, servers) + + // Toggle RouteAll + config.RouteAll = !tc.initialRoute + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers = configurator.getOriginalNameservers() + t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers) + assert.Equal(t, initialServers, servers) + + // Toggle back + config.RouteAll = tc.initialRoute + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers = configurator.getOriginalNameservers() + t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers) + assert.Equal(t, initialServers, servers) + + for _, srv := range servers { + assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP") + } + }) + } +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 1ce7bf1c6..4d4fcc06e 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -615,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() { s.registerFallback(config) } -// registerFallback registers original nameservers as low-priority fallback handlers +// registerFallback registers original nameservers as low-priority fallback handlers. func (s *DefaultServer) registerFallback(config HostDNSConfig) { hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) if !ok { @@ -624,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { originalNameservers := hostMgrWithNS.getOriginalNameservers() if len(originalNameservers) == 0 { + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) return } diff --git a/client/internal/dns/test/mock.go b/client/internal/dns/test/mock.go index 1db452805..8d16689bf 100644 --- a/client/internal/dns/test/mock.go +++ b/client/internal/dns/test/mock.go @@ -8,15 +8,21 @@ import ( type MockResponseWriter struct { WriteMsgFunc func(m *dns.Msg) error + lastResponse *dns.Msg } func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { + rw.lastResponse = m if rw.WriteMsgFunc != nil { return rw.WriteMsgFunc(m) } return nil } +func (rw *MockResponseWriter) GetLastResponse() *dns.Msg { + return rw.lastResponse +} + func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index f0693e82c..63ba1c9f2 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -573,9 +573,11 @@ func (e *Engine) createFirewall() error { var err error e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) - if err != nil || e.firewall == nil { - log.Errorf("failed creating firewall manager: %s", err) - return nil + if err != nil { + return fmt.Errorf("create firewall manager: %w", err) + } + if e.firewall == nil { + return fmt.Errorf("create firewall manager: received nil manager") } if err := e.initFirewall(); err != nil { diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index 6d019258d..6dd81f68c 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -14,6 +14,7 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) @@ -37,6 +38,11 @@ func New() *NetworkMonitor { // Listen begins monitoring network changes. When a change is detected, this function will return without error. func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { + if netstack.IsEnabled() { + log.Debugf("Network monitor: skipping in netstack mode") + return nil + } + nw.mu.Lock() if nw.cancel != nil { nw.mu.Unlock() diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 39133a6d3..eb455431d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn } conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) + conn.enableWgWatcherIfNeeded() + presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) @@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn conn.wgProxyRelay.RedirectAs(ep) } - conn.enableWgWatcherIfNeeded() - conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { wgProxy.Work() presharedKey := conn.presharedKey(rci.rosenpassPubKey) + + conn.enableWgWatcherIfNeeded() + if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) @@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { return } - conn.enableWgWatcherIfNeeded() - wgConfigWorkaround() conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.currentConnPriority = conntype.Relay diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go index 78d70c15b..a870c1145 100644 --- a/client/internal/wg_iface_monitor.go +++ b/client/internal/wg_iface_monitor.go @@ -9,6 +9,8 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" ) // WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine @@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes return false, errors.New("not supported on mobile platforms") } + if netstack.IsEnabled() { + log.Debugf("Interface monitor: skipped in netstack mode") + return false, nil + } + if ifaceName == "" { log.Debugf("Interface monitor: empty interface name, skipping monitor") return false, errors.New("empty interface name") diff --git a/management/cmd/management.go b/management/cmd/management.go index 7da04074b..511168823 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -16,13 +16,13 @@ import ( "strings" "syscall" - "github.com/miekg/dns" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/server" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util/crypt" ) @@ -78,9 +78,8 @@ var ( } } - _, valid := dns.IsDomainName(dnsDomain) - if !valid || len(dnsDomain) > 192 { - return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain)) + if !nbdomain.IsValidDomainNoWildcard(dnsDomain) { + return fmt.Errorf("invalid dns-domain: %s", dnsDomain) } return nil diff --git a/management/internals/modules/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go index 15119045b..758f643d0 100644 --- a/management/internals/modules/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -187,10 +187,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } for accountID, peerIDs := range peerIDsPerAccount { - log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID) + log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID) err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { - log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) + log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err) } } } diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index 1551689b4..7ac2e379f 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -108,10 +108,19 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { + if e, ok := status.FromError(err); ok && e.Type() == status.NotFound { + log.WithContext(ctx).Tracef("DeletePeers: peer %s not found, skipping", peerID) + return nil + } return err } if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) { + log.WithContext(ctx).Tracef("DeletePeers: peer %s skipped (connected=%t, lastSeen=%s, threshold=%s, ephemeral=%t)", + peerID, peer.Status.Connected, + peer.Status.LastSeen.Format(time.RFC3339), + time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)).Format(time.RFC3339), + peer.Ephemeral) return nil } @@ -150,7 +159,8 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs return nil }) if err != nil { - return err + log.WithContext(ctx).Errorf("DeletePeers: failed to delete peer %s: %v", peerID, err) + continue } if m.integratedPeerValidator != nil { diff --git a/management/internals/modules/zones/records/record.go b/management/internals/modules/zones/records/record.go index e44de08f4..1488febb9 100644 --- a/management/internals/modules/zones/records/record.go +++ b/management/internals/modules/zones/records/record.go @@ -6,7 +6,7 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) @@ -63,7 +63,7 @@ func (r *Record) Validate() error { return errors.New("record name is required") } - if !util.IsValidDomain(r.Name) { + if !domain.IsValidDomain(r.Name) { return errors.New("invalid record name format") } @@ -81,8 +81,8 @@ func (r *Record) Validate() error { return err } case RecordTypeCNAME: - if !util.IsValidDomain(r.Content) { - return errors.New("invalid CNAME record format") + if !domain.IsValidDomainNoWildcard(r.Content) { + return errors.New("invalid CNAME target format") } default: return errors.New("invalid record type, must be A, AAAA, or CNAME") diff --git a/management/internals/modules/zones/zone.go b/management/internals/modules/zones/zone.go index 27adac1ac..f5ebed26c 100644 --- a/management/internals/modules/zones/zone.go +++ b/management/internals/modules/zones/zone.go @@ -6,7 +6,7 @@ import ( "github.com/rs/xid" "github.com/netbirdio/netbird/management/internals/modules/zones/records" - "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) @@ -73,7 +73,7 @@ func (z *Zone) Validate() error { return errors.New("zone name exceeds maximum length of 255 characters") } - if !util.IsValidDomain(z.Domain) { + if !domain.IsValidDomainNoWildcard(z.Domain) { return errors.New("invalid zone domain format") } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 32049d044..219baaf6d 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -17,13 +17,14 @@ import ( pb "github.com/golang/protobuf/proto" // nolint "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" - "github.com/netbirdio/netbird/shared/management/client/common" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/shared/management/client/common" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/idp" @@ -304,6 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) + s.cancelPeerRoutines(ctx, accountID, peer) return err } diff --git a/management/server/account.go b/management/server/account.go index d453b87c3..ba5f0cffa 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -26,6 +26,7 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -231,7 +232,7 @@ func BuildManager( // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 if am.singleAccountMode { - if !isDomainValid(singleAccountModeDomain) { + if !nbdomain.IsValidDomainNoWildcard(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain @@ -402,7 +403,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + if newSettings.DNSDomain != "" && !nbdomain.IsValidDomainNoWildcard(newSettings.DNSDomain) { return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } @@ -1691,10 +1692,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return nil } -var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) +// isDomainValid validates public/IDP domains using stricter rules than internal DNS domains. +// Requires at least 2-char alphabetic TLD and no single-label domains. +var publicDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) func isDomainValid(domain string) bool { - return invalidDomainRegexp.MatchString(domain) + return publicDomainRegexp.MatchString(domain) } func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index a3eb4ae2e..3d8c78912 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -3,10 +3,10 @@ package server import ( "context" "errors" - "regexp" + "fmt" + "strings" "unicode/utf8" - "github.com/miekg/dns" "github.com/rs/xid" nbdns "github.com/netbirdio/netbird/dns" @@ -15,11 +15,10 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) -const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` - var errInvalidDomainName = errors.New("invalid domain name") // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs @@ -305,16 +304,18 @@ func validateGroups(list []string, groups map[string]*types.Group) error { return nil } -var domainMatcher = regexp.MustCompile(domainPattern) - -func validateDomain(domain string) error { - if !domainMatcher.MatchString(domain) { - return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces") +// validateDomain validates a nameserver match domain. +// Converts unicode to punycode. Wildcards are not allowed for nameservers. +func validateDomain(d string) error { + if strings.HasPrefix(d, "*.") { + return errors.New("wildcards not allowed") } - _, valid := dns.IsDomainName(domain) - if !valid { - return errInvalidDomainName + // Nameservers allow trailing dot (FQDN format) + toValidate := strings.TrimSuffix(d, ".") + + if _, err := nbdomain.ValidateDomains([]string{toValidate}); err != nil { + return fmt.Errorf("%w: %w", errInvalidDomainName, err) } return nil diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 0d781e0d4..90b4b9687 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -901,82 +901,53 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, return account, nil } +// TestValidateDomain tests nameserver-specific domain validation. +// Core domain validation is tested in shared/management/domain/validate_test.go. +// This test only covers nameserver-specific behavior: wildcard rejection and unicode support. func TestValidateDomain(t *testing.T) { testCases := []struct { name string domain string errFunc require.ErrorAssertionFunc }{ + // Nameserver-specific: wildcards not allowed { - name: "Valid domain name with multiple labels", - domain: "123.example.com", + name: "Wildcard prefix rejected", + domain: "*.example.com", + errFunc: require.Error, + }, + { + name: "Wildcard in middle rejected", + domain: "a.*.example.com", + errFunc: require.Error, + }, + // Nameserver-specific: unicode converted to punycode + { + name: "Unicode domain converted to punycode", + domain: "münchen.de", errFunc: require.NoError, }, { - name: "Valid domain name with hyphen", - domain: "test-example.com", + name: "Unicode domain all labels", + domain: "中国.中国", + errFunc: require.NoError, + }, + // Basic validation still works (delegates to shared validation) + { + name: "Valid multi-label domain", + domain: "example.com", errFunc: require.NoError, }, { - name: "Valid domain name with only one label", - domain: "example", + name: "Valid single label", + domain: "internal", errFunc: require.NoError, }, { - name: "Valid domain name with trailing dot", - domain: "example.", - errFunc: require.NoError, - }, - { - name: "Invalid wildcard domain name", - domain: "*.example", - errFunc: require.Error, - }, - { - name: "Invalid domain name with leading dot", - domain: ".com", - errFunc: require.Error, - }, - { - name: "Invalid domain name with dot only", - domain: ".", - errFunc: require.Error, - }, - { - name: "Invalid domain name with double hyphen", - domain: "test--example.com", - errFunc: require.Error, - }, - { - name: "Invalid domain name with a label exceeding 63 characters", - domain: "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com", - errFunc: require.Error, - }, - { - name: "Invalid domain name starting with a hyphen", + name: "Invalid leading hyphen", domain: "-example.com", errFunc: require.Error, }, - { - name: "Invalid domain name ending with a hyphen", - domain: "example.com-", - errFunc: require.Error, - }, - { - name: "Invalid domain with unicode", - domain: "example?,.com", - errFunc: require.Error, - }, - { - name: "Invalid domain with space before top-level domain", - domain: "space .example.com", - errFunc: require.Error, - }, - { - name: "Invalid domain with trailing space", - domain: "example.com ", - errFunc: require.Error, - }, } for _, testCase := range testCases { diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index e2dea2c6b..29b0af2cc 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -203,7 +203,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { NetworkID: "testNetworkId", Name: "testResourceId", Description: "description", - Address: "invalid-address", + Address: "-invalid", } store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) @@ -227,9 +227,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) { resource := &types.NetworkResource{ AccountID: "testAccountId", NetworkID: "testNetworkId", - Name: "testResourceId", + Name: "used-name", Description: "description", - Address: "invalid-address", + Address: "example.com", } store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 6b8cf9412..1fa908393 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/netip" - "regexp" "github.com/rs/xid" @@ -166,8 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil } - domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) - if domainRegex.MatchString(address) { + if _, err := nbDomain.ValidateDomains([]string{address}); err == nil { return Domain, address, netip.Prefix{}, nil } diff --git a/management/server/networks/resources/types/resource_test.go b/management/server/networks/resources/types/resource_test.go index 02e802300..a842b0a28 100644 --- a/management/server/networks/resources/types/resource_test.go +++ b/management/server/networks/resources/types/resource_test.go @@ -23,10 +23,12 @@ func TestGetResourceType(t *testing.T) { {"example.com", Domain, false, "example.com", netip.Prefix{}}, {"*.example.com", Domain, false, "*.example.com", netip.Prefix{}}, {"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}}, + {"example.x", Domain, false, "example.x", netip.Prefix{}}, + {"internal", Domain, false, "internal", netip.Prefix{}}, // Invalid inputs - {"invalid", "", true, "", netip.Prefix{}}, {"1.1.1.1/abc", "", true, "", netip.Prefix{}}, - {"1234", "", true, "", netip.Prefix{}}, + {"-invalid.com", "", true, "", netip.Prefix{}}, + {"", "", true, "", netip.Prefix{}}, } for _, tt := range tests { diff --git a/management/server/peer.go b/management/server/peer.go index 80c74e209..ab72d3051 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -728,11 +728,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed adding peer to All group: %w", err) } - if temporary { - // we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected - am.networkMapController.TrackEphemeralPeer(ctx, newPeer) - } - if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -760,6 +755,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed to increment network serial: %w", err) } + if ephemeral { + // we should track ephemeral peers to be able to clean them if the peer doesn't sync and isn't marked as connected + am.networkMapController.TrackEphemeralPeer(ctx, newPeer) + } + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) return nil }) diff --git a/management/server/util/util.go b/management/server/util/util.go index eea6a72b0..ce9759864 100644 --- a/management/server/util/util.go +++ b/management/server/util/util.go @@ -1,9 +1,5 @@ package util -import "regexp" - -var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) - // Difference returns the elements in `a` that aren't in `b`. func Difference(a, b []string) []string { mb := make(map[string]struct{}, len(b)) @@ -55,9 +51,3 @@ func contains[T comparableObject[T]](slice []T, element T) bool { return false } -func IsValidDomain(domain string) bool { - if domain == "" { - return false - } - return domainRegex.MatchString(domain) -} diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index bf2af7116..1858b5d55 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -10,7 +10,30 @@ const maxDomains = 32 var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) -// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. +// IsValidDomain checks if a single domain string is valid. +// Does not convert unicode to punycode - domain must already be ASCII/punycode. +// Allows wildcard prefix (*.example.com). +func IsValidDomain(domain string) bool { + if domain == "" { + return false + } + return domainRegex.MatchString(strings.ToLower(domain)) +} + +// IsValidDomainNoWildcard checks if a single domain string is valid without wildcard prefix. +// Use for zone domains and CNAME targets where wildcards are not allowed. +func IsValidDomainNoWildcard(domain string) bool { + if domain == "" { + return false + } + if strings.HasPrefix(domain, "*.") { + return false + } + return domainRegex.MatchString(strings.ToLower(domain)) +} + +// ValidateDomains validates domains and converts unicode to punycode. +// Allows wildcard prefix (*.example.com). Maximum 32 domains. func ValidateDomains(domains []string) (List, error) { if len(domains) == 0 { return nil, fmt.Errorf("domains list is empty") @@ -37,7 +60,10 @@ func ValidateDomains(domains []string) (List, error) { return domainList, nil } -// ValidateDomainsList checks if each domain in the list is valid +// ValidateDomainsList validates domains without punycode conversion. +// Use this for domains that must already be in ASCII/punycode format (e.g., extra DNS labels). +// Unlike ValidateDomains, this does not convert unicode to punycode - unicode domains will fail. +// Allows wildcard prefix (*.example.com). Maximum 32 domains. func ValidateDomainsList(domains []string) error { if len(domains) == 0 { return nil diff --git a/shared/management/domain/validate_test.go b/shared/management/domain/validate_test.go index 30efcd9a9..9dbcd8ac8 100644 --- a/shared/management/domain/validate_test.go +++ b/shared/management/domain/validate_test.go @@ -2,12 +2,16 @@ package domain import ( "fmt" + "strings" "testing" "github.com/stretchr/testify/assert" ) func TestValidateDomains(t *testing.T) { + label63 := strings.Repeat("a", 63) + label64 := strings.Repeat("a", 64) + tests := []struct { name string domains []string @@ -26,6 +30,48 @@ func TestValidateDomains(t *testing.T) { expected: List{"sub.ex-ample.com"}, wantErr: false, }, + { + name: "Valid uppercase domain normalized to lowercase", + domains: []string{"EXAMPLE.COM"}, + expected: List{"example.com"}, + wantErr: false, + }, + { + name: "Valid mixed case domain", + domains: []string{"ExAmPlE.CoM"}, + expected: List{"example.com"}, + wantErr: false, + }, + { + name: "Single letter TLD", + domains: []string{"example.x"}, + expected: List{"example.x"}, + wantErr: false, + }, + { + name: "Two letter domain labels", + domains: []string{"a.b"}, + expected: List{"a.b"}, + wantErr: false, + }, + { + name: "Single character domain", + domains: []string{"x"}, + expected: List{"x"}, + wantErr: false, + }, + { + name: "Wildcard with single letter TLD", + domains: []string{"*.x"}, + expected: List{"*.x"}, + wantErr: false, + }, + { + name: "Multi-level with single letter labels", + domains: []string{"a.b.c"}, + expected: List{"a.b.c"}, + wantErr: false, + }, { name: "Valid Unicode domain", domains: []string{"münchen.de"}, @@ -45,17 +91,92 @@ func TestValidateDomains(t *testing.T) { wantErr: false, }, { - name: "Invalid domain format", + name: "Valid domain starting with digit", + domains: []string{"123.example.com"}, + expected: List{"123.example.com"}, + wantErr: false, + }, + // Numeric TLDs are allowed for internal/private DNS use cases. + // While ICANN doesn't issue all-numeric gTLDs, the DNS protocol permits them + // and resolvers like systemd-resolved handle them correctly. + { + name: "Numeric TLD allowed", + domains: []string{"example.123"}, + expected: List{"example.123"}, + wantErr: false, + }, + { + name: "Single digit TLD allowed", + domains: []string{"example.1"}, + expected: List{"example.1"}, + wantErr: false, + }, + { + name: "All numeric labels allowed", + domains: []string{"123.456"}, + expected: List{"123.456"}, + wantErr: false, + }, + { + name: "Single numeric label allowed", + domains: []string{"123"}, + expected: List{"123"}, + wantErr: false, + }, + { + name: "Valid domain with double hyphen", + domains: []string{"test--example.com"}, + expected: List{"test--example.com"}, + wantErr: false, + }, + { + name: "Invalid leading hyphen", domains: []string{"-example.com"}, expected: nil, wantErr: true, }, { - name: "Invalid domain format 2", + name: "Invalid trailing hyphen", domains: []string{"example.com-"}, expected: nil, wantErr: true, }, + { + name: "Invalid leading dot", + domains: []string{".com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid dot only", + domains: []string{"."}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid double dot", + domains: []string{"example..com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid special characters", + domains: []string{"example?,.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid space in domain", + domains: []string{"space .example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid trailing space", + domains: []string{"example.com "}, + expected: nil, + wantErr: true, + }, { name: "Multiple domains valid and invalid", domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"}, @@ -86,6 +207,30 @@ func TestValidateDomains(t *testing.T) { expected: nil, wantErr: true, }, + { + name: "Valid 63 char label (max)", + domains: []string{label63 + ".com"}, + expected: List{Domain(label63 + ".com")}, + wantErr: false, + }, + { + name: "Invalid 64 char label (exceeds max)", + domains: []string{label64 + ".com"}, + expected: nil, + wantErr: true, + }, + { + name: "Valid 253 char domain (max)", + domains: []string{strings.Repeat("a.", 126) + "a"}, + expected: List{Domain(strings.Repeat("a.", 126) + "a")}, + wantErr: false, + }, + { + name: "Invalid 254+ char domain (exceeds max)", + domains: []string{strings.Repeat("ab.", 85)}, + expected: nil, + wantErr: true, + }, } for _, tt := range tests { @@ -118,6 +263,57 @@ func TestValidateDomainsList(t *testing.T) { domains: []string{"sub.ex-ample.com"}, wantErr: false, }, + { + name: "Uppercase domain accepted", + domains: []string{"EXAMPLE.COM"}, + wantErr: false, + }, + { + name: "Single letter TLD", + domains: []string{"example.x"}, + wantErr: false, + }, + { + name: "Two letter domain labels", + domains: []string{"a.b"}, + wantErr: false, + }, + { + name: "Single character domain", + domains: []string{"x"}, + wantErr: false, + }, + { + name: "Wildcard with single letter TLD", + domains: []string{"*.x"}, + wantErr: false, + }, + { + name: "Multi-level with single letter labels", + domains: []string{"a.b.c"}, + wantErr: false, + }, + // Numeric TLDs are allowed for internal/private DNS use cases. + { + name: "Numeric TLD allowed", + domains: []string{"example.123"}, + wantErr: false, + }, + { + name: "Single digit TLD allowed", + domains: []string{"example.1"}, + wantErr: false, + }, + { + name: "All numeric labels allowed", + domains: []string{"123.456"}, + wantErr: false, + }, + { + name: "Single numeric label allowed", + domains: []string{"123"}, + wantErr: false, + }, { name: "Underscores in labels", domains: []string{"_jabber._tcp.gmail.com"},