Merge branch 'main' into prototype/reverse-proxy

This commit is contained in:
Alisdair MacLeod
2026-01-30 14:46:08 +00:00
28 changed files with 729 additions and 203 deletions

View File

@@ -69,6 +69,8 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
}
// validateCredentials checks that exactly one credential type is provided
@@ -137,6 +139,7 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)

View File

@@ -8,8 +8,6 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
@@ -26,16 +24,6 @@ const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
@@ -253,63 +241,3 @@ generatePort:
}
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var dstIP net.IP
var rawConn net.PacketConn
if endpointAddr.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
dstIP = localHostNetIPv4
rawConn = p.rawConnIPv4
} else {
// IPv6 path
if p.rawConnIPv6 == nil {
return fmt.Errorf("IPv6 raw socket not available")
}
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpointAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
dstIP = localHostNetIPv6
rawConn = p.rawConnIPv6
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,12 +10,89 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
ctx context.Context
cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
wgRelayedEndpointAddr *net.UDPAddr
headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
paused bool
pausedCond *sync.Cond
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr
return err
p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
if !p.isStarted {
p.isStarted = true
@@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
@@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait()
}
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
p.pausedCond.L.Unlock()
if err != nil {
@@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
return n, nil
}
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false,
shouldMatch: false,
},
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
}
for _, tt := range tests {

View File

@@ -9,8 +9,10 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -38,6 +40,9 @@ const (
type systemConfigurator struct {
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
}
func newHostManager() (*systemConfigurator, error) {
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
}
var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false
inServerAddressesArray := false
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
ip = ip.Unmap()
serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
}
}
}
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port
dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil
}
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.origNameservers)
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {

View File

@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput()
return err
}
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -615,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config)
}
// registerFallback registers original nameservers as low-priority fallback handlers
// registerFallback registers original nameservers as low-priority fallback handlers.
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok {
@@ -624,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return
}

View File

@@ -8,15 +8,21 @@ import (
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -573,9 +573,11 @@ func (e *Engine) createFirewall() error {
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
if err != nil {
return fmt.Errorf("create firewall manager: %w", err)
}
if e.firewall == nil {
return fmt.Errorf("create firewall manager: received nil manager")
}
if err := e.initFirewall(); err != nil {

View File

@@ -14,6 +14,7 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()

View File

@@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
@@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
@@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
@@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.enableWgWatcherIfNeeded()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay

View File

@@ -9,6 +9,8 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
return false, errors.New("not supported on mobile platforms")
}
if netstack.IsEnabled() {
log.Debugf("Interface monitor: skipped in netstack mode")
return false, nil
}
if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name")

View File

@@ -16,13 +16,13 @@ import (
"strings"
"syscall"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
)
@@ -78,9 +78,8 @@ var (
}
}
_, valid := dns.IsDomainName(dnsDomain)
if !valid || len(dnsDomain) > 192 {
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain))
if !nbdomain.IsValidDomainNoWildcard(dnsDomain) {
return fmt.Errorf("invalid dns-domain: %s", dnsDomain)
}
return nil

View File

@@ -187,10 +187,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
}
}
}

View File

@@ -108,10 +108,19 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
log.WithContext(ctx).Tracef("DeletePeers: peer %s not found, skipping", peerID)
return nil
}
return err
}
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
log.WithContext(ctx).Tracef("DeletePeers: peer %s skipped (connected=%t, lastSeen=%s, threshold=%s, ephemeral=%t)",
peerID, peer.Status.Connected,
peer.Status.LastSeen.Format(time.RFC3339),
time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)).Format(time.RFC3339),
peer.Ephemeral)
return nil
}
@@ -150,7 +159,8 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
return nil
})
if err != nil {
return err
log.WithContext(ctx).Errorf("DeletePeers: failed to delete peer %s: %v", peerID, err)
continue
}
if m.integratedPeerValidator != nil {

View File

@@ -6,7 +6,7 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -63,7 +63,7 @@ func (r *Record) Validate() error {
return errors.New("record name is required")
}
if !util.IsValidDomain(r.Name) {
if !domain.IsValidDomain(r.Name) {
return errors.New("invalid record name format")
}
@@ -81,8 +81,8 @@ func (r *Record) Validate() error {
return err
}
case RecordTypeCNAME:
if !util.IsValidDomain(r.Content) {
return errors.New("invalid CNAME record format")
if !domain.IsValidDomainNoWildcard(r.Content) {
return errors.New("invalid CNAME target format")
}
default:
return errors.New("invalid record type, must be A, AAAA, or CNAME")

View File

@@ -6,7 +6,7 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -73,7 +73,7 @@ func (z *Zone) Validate() error {
return errors.New("zone name exceeds maximum length of 255 characters")
}
if !util.IsValidDomain(z.Domain) {
if !domain.IsValidDomainNoWildcard(z.Domain) {
return errors.New("invalid zone domain format")
}

View File

@@ -17,13 +17,14 @@ import (
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/shared/management/client/common"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
@@ -304,6 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}

View File

@@ -26,6 +26,7 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -231,7 +232,7 @@ func BuildManager(
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) {
if !nbdomain.IsValidDomainNoWildcard(singleAccountModeDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
}
am.singleAccountModeDomain = singleAccountModeDomain
@@ -402,7 +403,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
if newSettings.DNSDomain != "" && !nbdomain.IsValidDomainNoWildcard(newSettings.DNSDomain) {
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
@@ -1691,10 +1692,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return nil
}
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
// isDomainValid validates public/IDP domains using stricter rules than internal DNS domains.
// Requires at least 2-char alphabetic TLD and no single-label domains.
var publicDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain)
return publicDomainRegexp.MatchString(domain)
}
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {

View File

@@ -3,10 +3,10 @@ package server
import (
"context"
"errors"
"regexp"
"fmt"
"strings"
"unicode/utf8"
"github.com/miekg/dns"
"github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns"
@@ -15,11 +15,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
)
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
var errInvalidDomainName = errors.New("invalid domain name")
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
@@ -305,16 +304,18 @@ func validateGroups(list []string, groups map[string]*types.Group) error {
return nil
}
var domainMatcher = regexp.MustCompile(domainPattern)
func validateDomain(domain string) error {
if !domainMatcher.MatchString(domain) {
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
// validateDomain validates a nameserver match domain.
// Converts unicode to punycode. Wildcards are not allowed for nameservers.
func validateDomain(d string) error {
if strings.HasPrefix(d, "*.") {
return errors.New("wildcards not allowed")
}
_, valid := dns.IsDomainName(domain)
if !valid {
return errInvalidDomainName
// Nameservers allow trailing dot (FQDN format)
toValidate := strings.TrimSuffix(d, ".")
if _, err := nbdomain.ValidateDomains([]string{toValidate}); err != nil {
return fmt.Errorf("%w: %w", errInvalidDomainName, err)
}
return nil

View File

@@ -901,82 +901,53 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
return account, nil
}
// TestValidateDomain tests nameserver-specific domain validation.
// Core domain validation is tested in shared/management/domain/validate_test.go.
// This test only covers nameserver-specific behavior: wildcard rejection and unicode support.
func TestValidateDomain(t *testing.T) {
testCases := []struct {
name string
domain string
errFunc require.ErrorAssertionFunc
}{
// Nameserver-specific: wildcards not allowed
{
name: "Valid domain name with multiple labels",
domain: "123.example.com",
name: "Wildcard prefix rejected",
domain: "*.example.com",
errFunc: require.Error,
},
{
name: "Wildcard in middle rejected",
domain: "a.*.example.com",
errFunc: require.Error,
},
// Nameserver-specific: unicode converted to punycode
{
name: "Unicode domain converted to punycode",
domain: "münchen.de",
errFunc: require.NoError,
},
{
name: "Valid domain name with hyphen",
domain: "test-example.com",
name: "Unicode domain all labels",
domain: "中国.中国",
errFunc: require.NoError,
},
// Basic validation still works (delegates to shared validation)
{
name: "Valid multi-label domain",
domain: "example.com",
errFunc: require.NoError,
},
{
name: "Valid domain name with only one label",
domain: "example",
name: "Valid single label",
domain: "internal",
errFunc: require.NoError,
},
{
name: "Valid domain name with trailing dot",
domain: "example.",
errFunc: require.NoError,
},
{
name: "Invalid wildcard domain name",
domain: "*.example",
errFunc: require.Error,
},
{
name: "Invalid domain name with leading dot",
domain: ".com",
errFunc: require.Error,
},
{
name: "Invalid domain name with dot only",
domain: ".",
errFunc: require.Error,
},
{
name: "Invalid domain name with double hyphen",
domain: "test--example.com",
errFunc: require.Error,
},
{
name: "Invalid domain name with a label exceeding 63 characters",
domain: "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com",
errFunc: require.Error,
},
{
name: "Invalid domain name starting with a hyphen",
name: "Invalid leading hyphen",
domain: "-example.com",
errFunc: require.Error,
},
{
name: "Invalid domain name ending with a hyphen",
domain: "example.com-",
errFunc: require.Error,
},
{
name: "Invalid domain with unicode",
domain: "example?,.com",
errFunc: require.Error,
},
{
name: "Invalid domain with space before top-level domain",
domain: "space .example.com",
errFunc: require.Error,
},
{
name: "Invalid domain with trailing space",
domain: "example.com ",
errFunc: require.Error,
},
}
for _, testCase := range testCases {

View File

@@ -203,7 +203,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
NetworkID: "testNetworkId",
Name: "testResourceId",
Description: "description",
Address: "invalid-address",
Address: "-invalid",
}
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
@@ -227,9 +227,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
resource := &types.NetworkResource{
AccountID: "testAccountId",
NetworkID: "testNetworkId",
Name: "testResourceId",
Name: "used-name",
Description: "description",
Address: "invalid-address",
Address: "example.com",
}
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())

View File

@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/netip"
"regexp"
"github.com/rs/xid"
@@ -166,8 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix,
return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
}
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
if domainRegex.MatchString(address) {
if _, err := nbDomain.ValidateDomains([]string{address}); err == nil {
return Domain, address, netip.Prefix{}, nil
}

View File

@@ -23,10 +23,12 @@ func TestGetResourceType(t *testing.T) {
{"example.com", Domain, false, "example.com", netip.Prefix{}},
{"*.example.com", Domain, false, "*.example.com", netip.Prefix{}},
{"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}},
{"example.x", Domain, false, "example.x", netip.Prefix{}},
{"internal", Domain, false, "internal", netip.Prefix{}},
// Invalid inputs
{"invalid", "", true, "", netip.Prefix{}},
{"1.1.1.1/abc", "", true, "", netip.Prefix{}},
{"1234", "", true, "", netip.Prefix{}},
{"-invalid.com", "", true, "", netip.Prefix{}},
{"", "", true, "", netip.Prefix{}},
}
for _, tt := range tests {

View File

@@ -728,11 +728,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if temporary {
// we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -760,6 +755,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed to increment network serial: %w", err)
}
if ephemeral {
// we should track ephemeral peers to be able to clean them if the peer doesn't sync and isn't marked as connected
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})

View File

@@ -1,9 +1,5 @@
package util
import "regexp"
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
// Difference returns the elements in `a` that aren't in `b`.
func Difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
@@ -55,9 +51,3 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
return false
}
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(domain)
}

View File

@@ -10,7 +10,30 @@ const maxDomains = 32
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
// IsValidDomain checks if a single domain string is valid.
// Does not convert unicode to punycode - domain must already be ASCII/punycode.
// Allows wildcard prefix (*.example.com).
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(strings.ToLower(domain))
}
// IsValidDomainNoWildcard checks if a single domain string is valid without wildcard prefix.
// Use for zone domains and CNAME targets where wildcards are not allowed.
func IsValidDomainNoWildcard(domain string) bool {
if domain == "" {
return false
}
if strings.HasPrefix(domain, "*.") {
return false
}
return domainRegex.MatchString(strings.ToLower(domain))
}
// ValidateDomains validates domains and converts unicode to punycode.
// Allows wildcard prefix (*.example.com). Maximum 32 domains.
func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
@@ -37,7 +60,10 @@ func ValidateDomains(domains []string) (List, error) {
return domainList, nil
}
// ValidateDomainsList checks if each domain in the list is valid
// ValidateDomainsList validates domains without punycode conversion.
// Use this for domains that must already be in ASCII/punycode format (e.g., extra DNS labels).
// Unlike ValidateDomains, this does not convert unicode to punycode - unicode domains will fail.
// Allows wildcard prefix (*.example.com). Maximum 32 domains.
func ValidateDomainsList(domains []string) error {
if len(domains) == 0 {
return nil

View File

@@ -2,12 +2,16 @@ package domain
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateDomains(t *testing.T) {
label63 := strings.Repeat("a", 63)
label64 := strings.Repeat("a", 64)
tests := []struct {
name string
domains []string
@@ -26,6 +30,48 @@ func TestValidateDomains(t *testing.T) {
expected: List{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Valid uppercase domain normalized to lowercase",
domains: []string{"EXAMPLE.COM"},
expected: List{"example.com"},
wantErr: false,
},
{
name: "Valid mixed case domain",
domains: []string{"ExAmPlE.CoM"},
expected: List{"example.com"},
wantErr: false,
},
{
name: "Single letter TLD",
domains: []string{"example.x"},
expected: List{"example.x"},
wantErr: false,
},
{
name: "Two letter domain labels",
domains: []string{"a.b"},
expected: List{"a.b"},
wantErr: false,
},
{
name: "Single character domain",
domains: []string{"x"},
expected: List{"x"},
wantErr: false,
},
{
name: "Wildcard with single letter TLD",
domains: []string{"*.x"},
expected: List{"*.x"},
wantErr: false,
},
{
name: "Multi-level with single letter labels",
domains: []string{"a.b.c"},
expected: List{"a.b.c"},
wantErr: false,
},
{
name: "Valid Unicode domain",
domains: []string{"münchen.de"},
@@ -45,17 +91,92 @@ func TestValidateDomains(t *testing.T) {
wantErr: false,
},
{
name: "Invalid domain format",
name: "Valid domain starting with digit",
domains: []string{"123.example.com"},
expected: List{"123.example.com"},
wantErr: false,
},
// Numeric TLDs are allowed for internal/private DNS use cases.
// While ICANN doesn't issue all-numeric gTLDs, the DNS protocol permits them
// and resolvers like systemd-resolved handle them correctly.
{
name: "Numeric TLD allowed",
domains: []string{"example.123"},
expected: List{"example.123"},
wantErr: false,
},
{
name: "Single digit TLD allowed",
domains: []string{"example.1"},
expected: List{"example.1"},
wantErr: false,
},
{
name: "All numeric labels allowed",
domains: []string{"123.456"},
expected: List{"123.456"},
wantErr: false,
},
{
name: "Single numeric label allowed",
domains: []string{"123"},
expected: List{"123"},
wantErr: false,
},
{
name: "Valid domain with double hyphen",
domains: []string{"test--example.com"},
expected: List{"test--example.com"},
wantErr: false,
},
{
name: "Invalid leading hyphen",
domains: []string{"-example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid domain format 2",
name: "Invalid trailing hyphen",
domains: []string{"example.com-"},
expected: nil,
wantErr: true,
},
{
name: "Invalid leading dot",
domains: []string{".com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid dot only",
domains: []string{"."},
expected: nil,
wantErr: true,
},
{
name: "Invalid double dot",
domains: []string{"example..com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid special characters",
domains: []string{"example?,.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid space in domain",
domains: []string{"space .example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid trailing space",
domains: []string{"example.com "},
expected: nil,
wantErr: true,
},
{
name: "Multiple domains valid and invalid",
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
@@ -86,6 +207,30 @@ func TestValidateDomains(t *testing.T) {
expected: nil,
wantErr: true,
},
{
name: "Valid 63 char label (max)",
domains: []string{label63 + ".com"},
expected: List{Domain(label63 + ".com")},
wantErr: false,
},
{
name: "Invalid 64 char label (exceeds max)",
domains: []string{label64 + ".com"},
expected: nil,
wantErr: true,
},
{
name: "Valid 253 char domain (max)",
domains: []string{strings.Repeat("a.", 126) + "a"},
expected: List{Domain(strings.Repeat("a.", 126) + "a")},
wantErr: false,
},
{
name: "Invalid 254+ char domain (exceeds max)",
domains: []string{strings.Repeat("ab.", 85)},
expected: nil,
wantErr: true,
},
}
for _, tt := range tests {
@@ -118,6 +263,57 @@ func TestValidateDomainsList(t *testing.T) {
domains: []string{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Uppercase domain accepted",
domains: []string{"EXAMPLE.COM"},
wantErr: false,
},
{
name: "Single letter TLD",
domains: []string{"example.x"},
wantErr: false,
},
{
name: "Two letter domain labels",
domains: []string{"a.b"},
wantErr: false,
},
{
name: "Single character domain",
domains: []string{"x"},
wantErr: false,
},
{
name: "Wildcard with single letter TLD",
domains: []string{"*.x"},
wantErr: false,
},
{
name: "Multi-level with single letter labels",
domains: []string{"a.b.c"},
wantErr: false,
},
// Numeric TLDs are allowed for internal/private DNS use cases.
{
name: "Numeric TLD allowed",
domains: []string{"example.123"},
wantErr: false,
},
{
name: "Single digit TLD allowed",
domains: []string{"example.1"},
wantErr: false,
},
{
name: "All numeric labels allowed",
domains: []string{"123.456"},
wantErr: false,
},
{
name: "Single numeric label allowed",
domains: []string{"123"},
wantErr: false,
},
{
name: "Underscores in labels",
domains: []string{"_jabber._tcp.gmail.com"},