Compare commits

...

12 Commits

Author SHA1 Message Date
Zoltán Papp
85e4afe100 Refactor route manager to use sync.Once for wgInterface access
- Added `toInterface` helper to ensure `wgInterface.ToInterface()` is called once.
- Updated route addition/removal methods to use `toInterface` for efficiency.
2026-02-03 15:26:37 +01:00
Zoltán Papp
a72e0e67e5 Optimize Windows DNS performance with domain batching and batch mode
Implement two-layer optimization to reduce Windows NRPT registry operations:

1. Domain Batching (host_windows.go):
  - Batch up to 50 domains per NRPT rule (Windows undocumented limit)
  - Reduces NRPT rules by ~97% (e.g., 184 domains: 184 rules → 4 rules)
  - Modified addDNSMatchPolicy() to create batched NRPT entries
  - Added comprehensive tests in host_windows_test.go

2. Batch Mode (server.go):
  - Added BeginBatch/EndBatch methods to defer DNS updates
  - Modified RegisterHandler/DeregisterHandler to skip applyHostConfig in batch mode
  - Protected all applyHostConfig() calls with batch mode checks
  - Updated route manager to wrap route operations with batch calls
2026-02-03 00:15:04 +01:00
Zoltán Papp
a33f60e3fd Adds timing measurement to handleSync to help diagnose sync performance issues 2026-02-01 00:12:11 +01:00
Viktor Liu
0c990ab662 [client] Add block inbound option to the embed client (#5215) 2026-01-30 10:42:39 +01:00
Viktor Liu
101c813e98 [client] Add macOS default resolvers as fallback (#5201) 2026-01-30 10:42:14 +01:00
Zoltan Papp
5333e55a81 Fix WG watcher missing initial handshake (#5213)
Start the WireGuard watcher before configuring the WG endpoint to ensure it captures the initial handshake timestamp.

Previously, the watcher was started after endpoint configuration, causing it to miss the handshake that occurred during setup.
2026-01-29 16:58:10 +01:00
Viktor Liu
81c11df103 [management] Streamline domain validation (#5211) 2026-01-29 13:51:44 +01:00
Viktor Liu
f74bc48d16 [Client] Stop NetBird on firewall init failure (#5208) 2026-01-29 11:05:06 +01:00
Vlad
0169e4540f [management] fix skip of ephemeral peers on deletion (#5206) 2026-01-29 10:58:45 +01:00
Vlad
cead3f38ee [management] fix ephemeral peers being not removed (#5203) 2026-01-28 18:24:12 +01:00
Zoltan Papp
b55262d4a2 [client] Refactor/optimise raw socket headers (#5174)
Pre-create and reuse packet headers to eliminate per-packet allocations.
2026-01-28 15:06:59 +01:00
Zoltan Papp
2248ff392f Remove redundant square bracket trimming in USP endpoint parsing (#5197) 2026-01-27 20:10:59 +01:00
34 changed files with 931 additions and 223 deletions

View File

@@ -69,6 +69,8 @@ type Options struct {
StatePath string StatePath string
// DisableClientRoutes disables the client routes // DisableClientRoutes disables the client routes
DisableClientRoutes bool DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
} }
// validateCredentials checks that exactly one credential type is provided // validateCredentials checks that exactly one credential type is provided
@@ -137,6 +139,7 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey, PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t, DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes, DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
} }
if opts.ConfigPath != "" { if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input) config, err = profilemanager.UpdateOrCreateConfig(input)

View File

@@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue continue
} }
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]")) host, portStr, err := net.SplitHostPort(val)
if err != nil { if err != nil {
log.Errorf("failed to parse endpoint: %v", err) log.Errorf("failed to parse endpoint: %v", err)
continue continue

View File

@@ -8,8 +8,6 @@ import (
"net" "net"
"sync" "sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -26,16 +24,6 @@ const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
@@ -253,63 +241,3 @@ generatePort:
} }
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var dstIP net.IP
var rawConn net.PacketConn
if endpointAddr.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
dstIP = localHostNetIPv4
rawConn = p.rawConnIPv4
} else {
// IPv6 path
if p.rawConnIPv6 == nil {
return fmt.Errorf("IPv6 raw socket not available")
}
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpointAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
dstIP = localHostNetIPv6
rawConn = p.rawConnIPv6
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,12 +10,89 @@ import (
"net" "net"
"sync" "sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy wgeBPFProxy *WGEBPFProxy
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
paused bool paused bool
pausedCond *sync.Cond pausedCond *sync.Cond
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr p.wgRelayedEndpointAddr = addr
return err p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
@@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
log.Errorf("failed to start package redirection, endpoint is nil") log.Errorf("failed to start package redirection, endpoint is nil")
return return
} }
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
p.pausedCond.Signal() p.pausedCond.Signal()
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
@@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait() p.pausedCond.Wait()
} }
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) err = p.sendPkg(buf[:n], p.headerCurrentUsed)
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
if err != nil { if err != nil {
@@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
} }
return n, nil return n, nil
} }
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false, matchSubdomains: false,
shouldMatch: false, shouldMatch: false,
}, },
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -9,8 +9,10 @@ import (
"io" "io"
"net/netip" "net/netip"
"os/exec" "os/exec"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -38,6 +40,9 @@ const (
type systemConfigurator struct { type systemConfigurator struct {
createdKeys map[string]struct{} createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
} }
func newHostManager() (*systemConfigurator, error) { func newHostManager() (*systemConfigurator, error) {
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
var dnsSettings SystemDNSSettings var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false inSearchDomainsArray := false
inServerAddressesArray := false inServerAddressesArray := false
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
dnsSettings.ServerIP = ip.Unmap() ip = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
} }
} }
} }
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port // default to 53 port
dnsSettings.ServerPort = DefaultPort dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil return dnsSettings, nil
} }
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.origNameservers)
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true) err := s.addDNSState(key, domains, ip, port, true)
if err != nil { if err != nil {

View File

@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput() _, err := cmd.CombinedOutput()
return err return err
} }
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -42,6 +42,10 @@ const (
dnsPolicyConfigConfigOptionsKey = "ConfigOptions" dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8 dnsPolicyConfigConfigOptionsValue = 0x8
// NRPT rules cannot handle more than 50 domains per rule.
// This is an undocumented Windows limitation.
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
@@ -239,23 +243,32 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
singleDomain := []string{domain} // NRPT rules have an undocumented restriction: each rule can only handle up to 50 domains.
// We need to batch domains into chunks and create one NRPT rule per batch.
ruleIndex := 0
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
end := i + nrptMaxDomainsPerRule
if end > len(domains) {
end = len(domains)
}
batchDomains := domains[i:end]
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
} }
if r.gpo { if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err) return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex, err)
} }
} }
log.Debugf("added NRPT entry for domain: %s", domain) log.Debugf("added NRPT rule %d with %d domains", ruleIndex, len(batchDomains))
ruleIndex++
} }
if r.gpo { if r.gpo {
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
} }
} }
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains) log.Infof("added %d NRPT rules for %d domains. Domain list: %s", ruleIndex, len(domains), domains)
return len(domains), nil return ruleIndex, nil
} }
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {

View File

@@ -97,6 +97,107 @@ func registryKeyExists(path string) (bool, error) {
} }
func cleanupRegistryKeys(*testing.T) { func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10} // Clean up more entries to account for batching tests with many domains
cfg := &registryConfigurator{nrptEntryCount: 20}
_ = cfg.removeDNSMatchPolicies() _ = cfg.removeDNSMatchPolicies()
} }
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules
// with a maximum of 50 domains per rule (Windows limitation).
func TestNRPTDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
testCases := []struct {
name string
domainCount int
expectedRuleCount int
}{
{
name: "Less than 50 domains (single rule)",
domainCount: 30,
expectedRuleCount: 1,
},
{
name: "Exactly 50 domains (single rule)",
domainCount: 50,
expectedRuleCount: 1,
},
{
name: "51 domains (two rules)",
domainCount: 51,
expectedRuleCount: 2,
},
{
name: "100 domains (two rules)",
domainCount: 100,
expectedRuleCount: 2,
},
{
name: "125 domains (three rules: 50+50+25)",
domainCount: 125,
expectedRuleCount: 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clean up before each subtest
cleanupRegistryKeys(t)
// Generate domains
domains := make([]DomainConfig, tc.domainCount)
for i := 0; i < tc.domainCount; i++ {
domains[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config := HostDNSConfig{
ServerIP: testIP,
Domains: domains,
}
err := cfg.applyDNSConfig(config, nil)
require.NoError(t, err)
// Verify that exactly expectedRuleCount rules were created
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
// Verify all expected rules exist
for i := 0; i < tc.expectedRuleCount; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist", i)
}
// Verify no extra rules were created
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
require.NoError(t, err)
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
})
}
}

View File

@@ -84,3 +84,13 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil return nil
} }
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
}
// EndBatch mock implementation of EndBatch from Server interface
func (m *MockServer) EndBatch() {
// Mock implementation - no-op
}

View File

@@ -41,6 +41,8 @@ type IosDnsManager interface {
type Server interface { type Server interface {
RegisterHandler(domains domain.List, handler dns.Handler, priority int) RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int) DeregisterHandler(domains domain.List, priority int)
BeginBatch()
EndBatch()
Initialize() error Initialize() error
Stop() Stop()
DnsIP() netip.Addr DnsIP() netip.Addr
@@ -83,6 +85,7 @@ type DefaultServer struct {
currentConfigHash uint64 currentConfigHash uint64
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int extraDomains map[domain.Domain]int
batchMode bool
mgmtCacheResolver *mgmt.Resolver mgmtCacheResolver *mgmt.Resolver
@@ -230,7 +233,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter // convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++ s.extraDomains[toZone(domain)]++
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -259,6 +264,28 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
delete(s.extraDomains, zone) delete(s.extraDomains, zone)
} }
} }
if !s.batchMode {
s.applyHostConfig()
}
}
// BeginBatch starts batch mode for DNS handler registration/deregistration.
// In batch mode, applyHostConfig() is not called after each handler operation,
// allowing multiple handlers to be registered/deregistered efficiently.
// Must be followed by EndBatch() to apply the accumulated changes.
func (s *DefaultServer) BeginBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("DNS batch mode enabled")
s.batchMode = true
}
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
func (s *DefaultServer) EndBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("DNS batch mode disabled, applying accumulated changes")
s.batchMode = false
s.applyHostConfig() s.applyHostConfig()
} }
@@ -508,7 +535,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
s.shutdownWg.Add(1) s.shutdownWg.Add(1)
go func() { go func() {
@@ -615,7 +644,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config) s.registerFallback(config)
} }
// registerFallback registers original nameservers as low-priority fallback handlers // registerFallback registers original nameservers as low-priority fallback handlers.
func (s *DefaultServer) registerFallback(config HostDNSConfig) { func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok { if !ok {
@@ -624,6 +653,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers() originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 { if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return return
} }
@@ -871,7 +901,9 @@ func (s *DefaultServer) upstreamCallbacks(
} }
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
go func() { go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil { if err := s.stateManager.PersistState(s.ctx); err != nil {
@@ -906,7 +938,9 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority) s.registerHandler([]string{nbdns.RootZone}, handler, priority)
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
s.updateNSState(nsGroup, nil, true) s.updateNSState(nsGroup, nil, true)
} }

View File

@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
t.Errorf("invalid dns server instance: %s", err) t.Errorf("invalid dns server instance: %s", err)
} }
if srvB != srv { mockSrvB, ok := srvB.(*MockServer)
if !ok {
t.Errorf("returned server is not a MockServer")
}
if mockSrvB != srv {
t.Errorf("mismatch dns instances") t.Errorf("mismatch dns instances")
} }
} }

View File

@@ -8,15 +8,21 @@ import (
type MockResponseWriter struct { type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
} }
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil { if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m) return rw.WriteMsgFunc(m)
} }
return nil return nil
} }
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -573,9 +573,11 @@ func (e *Engine) createFirewall() error {
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil { if err != nil {
log.Errorf("failed creating firewall manager: %s", err) return fmt.Errorf("create firewall manager: %w", err)
return nil }
if e.firewall == nil {
return fmt.Errorf("create firewall manager: received nil manager")
} }
if err := e.initFirewall(); err != nil { if err := e.initFirewall(); err != nil {
@@ -826,6 +828,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
} }
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()

View File

@@ -14,6 +14,7 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error. // Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock() nw.mu.Lock()
if nw.cancel != nil { if nw.cancel != nil {
nw.mu.Unlock() nw.mu.Unlock()

View File

@@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
} }
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy) conn.handleConfigurationFailure(err, wgProxy)
@@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep) conn.wgProxyRelay.RedirectAs(ep)
} }
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
@@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgProxy.Work() wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey) presharedKey := conn.presharedKey(rci.rosenpassPubKey)
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err) conn.Log.Warnf("Failed to close relay connection: %v", err)
@@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
conn.enableWgWatcherIfNeeded()
wgConfigWorkaround() wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay conn.currentConnPriority = conntype.Relay

View File

@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
} }
func (m *DefaultManager) setupRefCounters(useNoop bool) { func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) { func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
}, },
func(prefix netip.Prefix, _ struct{}) error { func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) return m.sysOps.RemoveVPNRoute(prefix, toInterface())
}, },
) )
@@ -337,6 +346,13 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
} }
var merr *multierror.Error var merr *multierror.Error
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
if m.dnsServer != nil {
m.dnsServer.BeginBatch()
defer m.dnsServer.EndBatch()
}
for id, handler := range toRemove { for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil { if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err)) merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))

View File

@@ -9,6 +9,8 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
) )
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine // WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
return false, errors.New("not supported on mobile platforms") return false, errors.New("not supported on mobile platforms")
} }
if netstack.IsEnabled() {
log.Debugf("Interface monitor: skipped in netstack mode")
return false, nil
}
if ifaceName == "" { if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor") log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name") return false, errors.New("empty interface name")

View File

@@ -16,13 +16,13 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server" "github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt" "github.com/netbirdio/netbird/util/crypt"
) )
@@ -78,9 +78,8 @@ var (
} }
} }
_, valid := dns.IsDomainName(dnsDomain) if !nbdomain.IsValidDomainNoWildcard(dnsDomain) {
if !valid || len(dnsDomain) > 192 { return fmt.Errorf("invalid dns-domain: %s", dnsDomain)
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain))
} }
return nil return nil

View File

@@ -187,10 +187,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
} }
for accountID, peerIDs := range peerIDsPerAccount { for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID) log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
} }
} }
} }

View File

@@ -108,10 +108,19 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
log.WithContext(ctx).Tracef("DeletePeers: peer %s not found, skipping", peerID)
return nil
}
return err return err
} }
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) { if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
log.WithContext(ctx).Tracef("DeletePeers: peer %s skipped (connected=%t, lastSeen=%s, threshold=%s, ephemeral=%t)",
peerID, peer.Status.Connected,
peer.Status.LastSeen.Format(time.RFC3339),
time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)).Format(time.RFC3339),
peer.Ephemeral)
return nil return nil
} }
@@ -150,7 +159,8 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
return nil return nil
}) })
if err != nil { if err != nil {
return err log.WithContext(ctx).Errorf("DeletePeers: failed to delete peer %s: %v", peerID, err)
continue
} }
if m.integratedPeerValidator != nil { if m.integratedPeerValidator != nil {

View File

@@ -6,7 +6,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
) )
@@ -63,7 +63,7 @@ func (r *Record) Validate() error {
return errors.New("record name is required") return errors.New("record name is required")
} }
if !util.IsValidDomain(r.Name) { if !domain.IsValidDomain(r.Name) {
return errors.New("invalid record name format") return errors.New("invalid record name format")
} }
@@ -81,8 +81,8 @@ func (r *Record) Validate() error {
return err return err
} }
case RecordTypeCNAME: case RecordTypeCNAME:
if !util.IsValidDomain(r.Content) { if !domain.IsValidDomainNoWildcard(r.Content) {
return errors.New("invalid CNAME record format") return errors.New("invalid CNAME target format")
} }
default: default:
return errors.New("invalid record type, must be A, AAAA, or CNAME") return errors.New("invalid record type, must be A, AAAA, or CNAME")

View File

@@ -6,7 +6,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
) )
@@ -73,7 +73,7 @@ func (z *Zone) Validate() error {
return errors.New("zone name exceeds maximum length of 255 characters") return errors.New("zone name exceeds maximum length of 255 characters")
} }
if !util.IsValidDomain(z.Domain) { if !domain.IsValidDomainNoWildcard(z.Domain) {
return errors.New("invalid zone domain format") return errors.New("invalid zone domain format")
} }

View File

@@ -17,13 +17,14 @@ import (
pb "github.com/golang/protobuf/proto" // nolint pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/shared/management/client/common"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
@@ -304,6 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1) s.syncSem.Add(-1)
s.cancelPeerRoutines(ctx, accountID, peer)
return err return err
} }

View File

@@ -26,6 +26,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -231,7 +232,7 @@ func BuildManager(
// enable single account mode only if configured by user and number of existing accounts is not grater than 1 // enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
if am.singleAccountMode { if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) { if !nbdomain.IsValidDomainNoWildcard(singleAccountModeDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
} }
am.singleAccountModeDomain = singleAccountModeDomain am.singleAccountModeDomain = singleAccountModeDomain
@@ -402,7 +403,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { if newSettings.DNSDomain != "" && !nbdomain.IsValidDomainNoWildcard(newSettings.DNSDomain) {
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
@@ -1691,10 +1692,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return nil return nil
} }
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) // isDomainValid validates public/IDP domains using stricter rules than internal DNS domains.
// Requires at least 2-char alphabetic TLD and no single-label domains.
var publicDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool { func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain) return publicDomainRegexp.MatchString(domain)
} }
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {

View File

@@ -3,10 +3,10 @@ package server
import ( import (
"context" "context"
"errors" "errors"
"regexp" "fmt"
"strings"
"unicode/utf8" "unicode/utf8"
"github.com/miekg/dns"
"github.com/rs/xid" "github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -15,11 +15,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
var errInvalidDomainName = errors.New("invalid domain name") var errInvalidDomainName = errors.New("invalid domain name")
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
@@ -305,16 +304,18 @@ func validateGroups(list []string, groups map[string]*types.Group) error {
return nil return nil
} }
var domainMatcher = regexp.MustCompile(domainPattern) // validateDomain validates a nameserver match domain.
// Converts unicode to punycode. Wildcards are not allowed for nameservers.
func validateDomain(domain string) error { func validateDomain(d string) error {
if !domainMatcher.MatchString(domain) { if strings.HasPrefix(d, "*.") {
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces") return errors.New("wildcards not allowed")
} }
_, valid := dns.IsDomainName(domain) // Nameservers allow trailing dot (FQDN format)
if !valid { toValidate := strings.TrimSuffix(d, ".")
return errInvalidDomainName
if _, err := nbdomain.ValidateDomains([]string{toValidate}); err != nil {
return fmt.Errorf("%w: %w", errInvalidDomainName, err)
} }
return nil return nil

View File

@@ -901,82 +901,53 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
return account, nil return account, nil
} }
// TestValidateDomain tests nameserver-specific domain validation.
// Core domain validation is tested in shared/management/domain/validate_test.go.
// This test only covers nameserver-specific behavior: wildcard rejection and unicode support.
func TestValidateDomain(t *testing.T) { func TestValidateDomain(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
domain string domain string
errFunc require.ErrorAssertionFunc errFunc require.ErrorAssertionFunc
}{ }{
// Nameserver-specific: wildcards not allowed
{ {
name: "Valid domain name with multiple labels", name: "Wildcard prefix rejected",
domain: "123.example.com", domain: "*.example.com",
errFunc: require.Error,
},
{
name: "Wildcard in middle rejected",
domain: "a.*.example.com",
errFunc: require.Error,
},
// Nameserver-specific: unicode converted to punycode
{
name: "Unicode domain converted to punycode",
domain: "münchen.de",
errFunc: require.NoError, errFunc: require.NoError,
}, },
{ {
name: "Valid domain name with hyphen", name: "Unicode domain all labels",
domain: "test-example.com", domain: "中国.中国",
errFunc: require.NoError,
},
// Basic validation still works (delegates to shared validation)
{
name: "Valid multi-label domain",
domain: "example.com",
errFunc: require.NoError, errFunc: require.NoError,
}, },
{ {
name: "Valid domain name with only one label", name: "Valid single label",
domain: "example", domain: "internal",
errFunc: require.NoError, errFunc: require.NoError,
}, },
{ {
name: "Valid domain name with trailing dot", name: "Invalid leading hyphen",
domain: "example.",
errFunc: require.NoError,
},
{
name: "Invalid wildcard domain name",
domain: "*.example",
errFunc: require.Error,
},
{
name: "Invalid domain name with leading dot",
domain: ".com",
errFunc: require.Error,
},
{
name: "Invalid domain name with dot only",
domain: ".",
errFunc: require.Error,
},
{
name: "Invalid domain name with double hyphen",
domain: "test--example.com",
errFunc: require.Error,
},
{
name: "Invalid domain name with a label exceeding 63 characters",
domain: "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com",
errFunc: require.Error,
},
{
name: "Invalid domain name starting with a hyphen",
domain: "-example.com", domain: "-example.com",
errFunc: require.Error, errFunc: require.Error,
}, },
{
name: "Invalid domain name ending with a hyphen",
domain: "example.com-",
errFunc: require.Error,
},
{
name: "Invalid domain with unicode",
domain: "example?,.com",
errFunc: require.Error,
},
{
name: "Invalid domain with space before top-level domain",
domain: "space .example.com",
errFunc: require.Error,
},
{
name: "Invalid domain with trailing space",
domain: "example.com ",
errFunc: require.Error,
},
} }
for _, testCase := range testCases { for _, testCase := range testCases {

View File

@@ -203,7 +203,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
NetworkID: "testNetworkId", NetworkID: "testNetworkId",
Name: "testResourceId", Name: "testResourceId",
Description: "description", Description: "description",
Address: "invalid-address", Address: "-invalid",
} }
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
@@ -227,9 +227,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
resource := &types.NetworkResource{ resource := &types.NetworkResource{
AccountID: "testAccountId", AccountID: "testAccountId",
NetworkID: "testNetworkId", NetworkID: "testNetworkId",
Name: "testResourceId", Name: "used-name",
Description: "description", Description: "description",
Address: "invalid-address", Address: "example.com",
} }
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"regexp"
"github.com/rs/xid" "github.com/rs/xid"
@@ -166,8 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix,
return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
} }
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) if _, err := nbDomain.ValidateDomains([]string{address}); err == nil {
if domainRegex.MatchString(address) {
return Domain, address, netip.Prefix{}, nil return Domain, address, netip.Prefix{}, nil
} }

View File

@@ -23,10 +23,12 @@ func TestGetResourceType(t *testing.T) {
{"example.com", Domain, false, "example.com", netip.Prefix{}}, {"example.com", Domain, false, "example.com", netip.Prefix{}},
{"*.example.com", Domain, false, "*.example.com", netip.Prefix{}}, {"*.example.com", Domain, false, "*.example.com", netip.Prefix{}},
{"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}}, {"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}},
{"example.x", Domain, false, "example.x", netip.Prefix{}},
{"internal", Domain, false, "internal", netip.Prefix{}},
// Invalid inputs // Invalid inputs
{"invalid", "", true, "", netip.Prefix{}},
{"1.1.1.1/abc", "", true, "", netip.Prefix{}}, {"1.1.1.1/abc", "", true, "", netip.Prefix{}},
{"1234", "", true, "", netip.Prefix{}}, {"-invalid.com", "", true, "", netip.Prefix{}},
{"", "", true, "", netip.Prefix{}},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -728,11 +728,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed adding peer to All group: %w", err) return fmt.Errorf("failed adding peer to All group: %w", err)
} }
if temporary {
// we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
}
if addedByUser { if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil { if err != nil {
@@ -760,6 +755,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
if ephemeral {
// we should track ephemeral peers to be able to clean them if the peer doesn't sync and isn't marked as connected
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil return nil
}) })

View File

@@ -1,9 +1,5 @@
package util package util
import "regexp"
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
// Difference returns the elements in `a` that aren't in `b`. // Difference returns the elements in `a` that aren't in `b`.
func Difference(a, b []string) []string { func Difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b)) mb := make(map[string]struct{}, len(b))
@@ -55,9 +51,3 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
return false return false
} }
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(domain)
}

View File

@@ -10,7 +10,30 @@ const maxDomains = 32
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. // IsValidDomain checks if a single domain string is valid.
// Does not convert unicode to punycode - domain must already be ASCII/punycode.
// Allows wildcard prefix (*.example.com).
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(strings.ToLower(domain))
}
// IsValidDomainNoWildcard checks if a single domain string is valid without wildcard prefix.
// Use for zone domains and CNAME targets where wildcards are not allowed.
func IsValidDomainNoWildcard(domain string) bool {
if domain == "" {
return false
}
if strings.HasPrefix(domain, "*.") {
return false
}
return domainRegex.MatchString(strings.ToLower(domain))
}
// ValidateDomains validates domains and converts unicode to punycode.
// Allows wildcard prefix (*.example.com). Maximum 32 domains.
func ValidateDomains(domains []string) (List, error) { func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 { if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty") return nil, fmt.Errorf("domains list is empty")
@@ -37,7 +60,10 @@ func ValidateDomains(domains []string) (List, error) {
return domainList, nil return domainList, nil
} }
// ValidateDomainsList checks if each domain in the list is valid // ValidateDomainsList validates domains without punycode conversion.
// Use this for domains that must already be in ASCII/punycode format (e.g., extra DNS labels).
// Unlike ValidateDomains, this does not convert unicode to punycode - unicode domains will fail.
// Allows wildcard prefix (*.example.com). Maximum 32 domains.
func ValidateDomainsList(domains []string) error { func ValidateDomainsList(domains []string) error {
if len(domains) == 0 { if len(domains) == 0 {
return nil return nil

View File

@@ -2,12 +2,16 @@ package domain
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestValidateDomains(t *testing.T) { func TestValidateDomains(t *testing.T) {
label63 := strings.Repeat("a", 63)
label64 := strings.Repeat("a", 64)
tests := []struct { tests := []struct {
name string name string
domains []string domains []string
@@ -26,6 +30,48 @@ func TestValidateDomains(t *testing.T) {
expected: List{"sub.ex-ample.com"}, expected: List{"sub.ex-ample.com"},
wantErr: false, wantErr: false,
}, },
{
name: "Valid uppercase domain normalized to lowercase",
domains: []string{"EXAMPLE.COM"},
expected: List{"example.com"},
wantErr: false,
},
{
name: "Valid mixed case domain",
domains: []string{"ExAmPlE.CoM"},
expected: List{"example.com"},
wantErr: false,
},
{
name: "Single letter TLD",
domains: []string{"example.x"},
expected: List{"example.x"},
wantErr: false,
},
{
name: "Two letter domain labels",
domains: []string{"a.b"},
expected: List{"a.b"},
wantErr: false,
},
{
name: "Single character domain",
domains: []string{"x"},
expected: List{"x"},
wantErr: false,
},
{
name: "Wildcard with single letter TLD",
domains: []string{"*.x"},
expected: List{"*.x"},
wantErr: false,
},
{
name: "Multi-level with single letter labels",
domains: []string{"a.b.c"},
expected: List{"a.b.c"},
wantErr: false,
},
{ {
name: "Valid Unicode domain", name: "Valid Unicode domain",
domains: []string{"münchen.de"}, domains: []string{"münchen.de"},
@@ -45,17 +91,92 @@ func TestValidateDomains(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
name: "Invalid domain format", name: "Valid domain starting with digit",
domains: []string{"123.example.com"},
expected: List{"123.example.com"},
wantErr: false,
},
// Numeric TLDs are allowed for internal/private DNS use cases.
// While ICANN doesn't issue all-numeric gTLDs, the DNS protocol permits them
// and resolvers like systemd-resolved handle them correctly.
{
name: "Numeric TLD allowed",
domains: []string{"example.123"},
expected: List{"example.123"},
wantErr: false,
},
{
name: "Single digit TLD allowed",
domains: []string{"example.1"},
expected: List{"example.1"},
wantErr: false,
},
{
name: "All numeric labels allowed",
domains: []string{"123.456"},
expected: List{"123.456"},
wantErr: false,
},
{
name: "Single numeric label allowed",
domains: []string{"123"},
expected: List{"123"},
wantErr: false,
},
{
name: "Valid domain with double hyphen",
domains: []string{"test--example.com"},
expected: List{"test--example.com"},
wantErr: false,
},
{
name: "Invalid leading hyphen",
domains: []string{"-example.com"}, domains: []string{"-example.com"},
expected: nil, expected: nil,
wantErr: true, wantErr: true,
}, },
{ {
name: "Invalid domain format 2", name: "Invalid trailing hyphen",
domains: []string{"example.com-"}, domains: []string{"example.com-"},
expected: nil, expected: nil,
wantErr: true, wantErr: true,
}, },
{
name: "Invalid leading dot",
domains: []string{".com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid dot only",
domains: []string{"."},
expected: nil,
wantErr: true,
},
{
name: "Invalid double dot",
domains: []string{"example..com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid special characters",
domains: []string{"example?,.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid space in domain",
domains: []string{"space .example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid trailing space",
domains: []string{"example.com "},
expected: nil,
wantErr: true,
},
{ {
name: "Multiple domains valid and invalid", name: "Multiple domains valid and invalid",
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"}, domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
@@ -86,6 +207,30 @@ func TestValidateDomains(t *testing.T) {
expected: nil, expected: nil,
wantErr: true, wantErr: true,
}, },
{
name: "Valid 63 char label (max)",
domains: []string{label63 + ".com"},
expected: List{Domain(label63 + ".com")},
wantErr: false,
},
{
name: "Invalid 64 char label (exceeds max)",
domains: []string{label64 + ".com"},
expected: nil,
wantErr: true,
},
{
name: "Valid 253 char domain (max)",
domains: []string{strings.Repeat("a.", 126) + "a"},
expected: List{Domain(strings.Repeat("a.", 126) + "a")},
wantErr: false,
},
{
name: "Invalid 254+ char domain (exceeds max)",
domains: []string{strings.Repeat("ab.", 85)},
expected: nil,
wantErr: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -118,6 +263,57 @@ func TestValidateDomainsList(t *testing.T) {
domains: []string{"sub.ex-ample.com"}, domains: []string{"sub.ex-ample.com"},
wantErr: false, wantErr: false,
}, },
{
name: "Uppercase domain accepted",
domains: []string{"EXAMPLE.COM"},
wantErr: false,
},
{
name: "Single letter TLD",
domains: []string{"example.x"},
wantErr: false,
},
{
name: "Two letter domain labels",
domains: []string{"a.b"},
wantErr: false,
},
{
name: "Single character domain",
domains: []string{"x"},
wantErr: false,
},
{
name: "Wildcard with single letter TLD",
domains: []string{"*.x"},
wantErr: false,
},
{
name: "Multi-level with single letter labels",
domains: []string{"a.b.c"},
wantErr: false,
},
// Numeric TLDs are allowed for internal/private DNS use cases.
{
name: "Numeric TLD allowed",
domains: []string{"example.123"},
wantErr: false,
},
{
name: "Single digit TLD allowed",
domains: []string{"example.1"},
wantErr: false,
},
{
name: "All numeric labels allowed",
domains: []string{"123.456"},
wantErr: false,
},
{
name: "Single numeric label allowed",
domains: []string{"123"},
wantErr: false,
},
{ {
name: "Underscores in labels", name: "Underscores in labels",
domains: []string{"_jabber._tcp.gmail.com"}, domains: []string{"_jabber._tcp.gmail.com"},