diff --git a/client/firewall/firewalld/firewalld.go b/client/firewall/firewalld/firewalld.go new file mode 100644 index 000000000..188ea61dd --- /dev/null +++ b/client/firewall/firewalld/firewalld.go @@ -0,0 +1,11 @@ +// Package firewalld integrates with the firewalld daemon so NetBird can place +// its wg interface into firewalld's "trusted" zone. This is required because +// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent +// versions, which returns EPERM to any other process that tries to insert +// rules into them. The workaround mirrors what Tailscale does: let firewalld +// itself add the accept rules to its own chains by trusting the interface. +package firewalld + +// TrustedZone is the firewalld zone name used for interfaces whose traffic +// should bypass firewalld filtering. +const TrustedZone = "trusted" diff --git a/client/firewall/firewalld/firewalld_linux.go b/client/firewall/firewalld/firewalld_linux.go new file mode 100644 index 000000000..924a04b0a --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux.go @@ -0,0 +1,260 @@ +//go:build linux + +package firewalld + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" + "sync" + "time" + + "github.com/godbus/dbus/v5" + log "github.com/sirupsen/logrus" +) + +const ( + dbusDest = "org.fedoraproject.FirewallD1" + dbusPath = "/org/fedoraproject/FirewallD1" + dbusRootIface = "org.fedoraproject.FirewallD1" + dbusZoneIface = "org.fedoraproject.FirewallD1.zone" + + errZoneAlreadySet = "ZONE_ALREADY_SET" + errAlreadyEnabled = "ALREADY_ENABLED" + errUnknownIface = "UNKNOWN_INTERFACE" + errNotEnabled = "NOT_ENABLED" + + // callTimeout bounds each individual DBus or firewall-cmd invocation. + // A fresh context is created for each call so a slow DBus probe can't + // exhaust the deadline before the firewall-cmd fallback gets to run. + callTimeout = 3 * time.Second +) + +var ( + errDBusUnavailable = errors.New("firewalld dbus unavailable") + + // trustLogOnce ensures the "added to trusted zone" message is logged at + // Info level only for the first successful add per process; repeat adds + // from other init paths are quieter. + trustLogOnce sync.Once + + parentCtxMu sync.RWMutex + parentCtx context.Context = context.Background() +) + +// SetParentContext installs a parent context whose cancellation aborts any +// in-flight TrustInterface call. It does not affect UntrustInterface, which +// always uses a fresh Background-rooted timeout so cleanup can still run +// during engine shutdown when the engine context is already cancelled. +func SetParentContext(ctx context.Context) { + parentCtxMu.Lock() + parentCtx = ctx + parentCtxMu.Unlock() +} + +func getParentContext() context.Context { + parentCtxMu.RLock() + defer parentCtxMu.RUnlock() + return parentCtx +} + +// TrustInterface places iface into firewalld's trusted zone if firewalld is +// running. It is idempotent and best-effort: errors are returned so callers +// can log, but a non-running firewalld is not an error. Only the first +// successful call per process logs at Info. Respects the parent context set +// via SetParentContext so startup-time cancellation unblocks it. +func TrustInterface(iface string) error { + parent := getParentContext() + if !isRunning(parent) { + return nil + } + if err := addTrusted(parent, iface); err != nil { + return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err) + } + trustLogOnce.Do(func() { + log.Infof("added %s to firewalld trusted zone", iface) + }) + log.Debugf("firewalld: ensured %s is in trusted zone", iface) + return nil +} + +// UntrustInterface removes iface from firewalld's trusted zone if firewalld +// is running. Idempotent. Uses a Background-rooted timeout so it still runs +// during shutdown after the engine context has been cancelled. +func UntrustInterface(iface string) error { + if !isRunning(context.Background()) { + return nil + } + if err := removeTrusted(context.Background(), iface); err != nil { + return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err) + } + return nil +} + +func newCallContext(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, callTimeout) +} + +func isRunning(parent context.Context) bool { + ctx, cancel := newCallContext(parent) + ok, err := isRunningDBus(ctx) + cancel() + if err == nil { + return ok + } + if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) { + ctx, cancel = newCallContext(parent) + defer cancel() + return isRunningCLI(ctx) + } + return false +} + +func addTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := addDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return addCLI(ctx, iface) +} + +func removeTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := removeDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return removeCLI(ctx, iface) +} + +func isRunningDBus(ctx context.Context) (bool, error) { + conn, err := dbus.SystemBus() + if err != nil { + return false, fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + var zone string + if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil { + return false, fmt.Errorf("firewalld getDefaultZone: %w", err) + } + return true, nil +} + +func isRunningCLI(ctx context.Context) bool { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return false + } + return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil +} + +func addDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errAlreadyEnabled) { + return nil + } + + if dbusErrContains(call.Err, errZoneAlreadySet) { + move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface) + if move.Err != nil { + return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err) + } + return nil + } + + return fmt.Errorf("firewalld addInterface: %w", call.Err) +} + +func removeDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) { + return nil + } + + return fmt.Errorf("firewalld removeInterface: %w", call.Err) +} + +func addCLI(ctx context.Context, iface string) error { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + // --change-interface (no --permanent) binds the interface for the + // current runtime only; we do not want membership to persist across + // reboots because netbird re-asserts it on every startup. + out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface, + ).CombinedOutput() + if err != nil { + return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out))) + } + return nil +} + +func removeCLI(ctx context.Context, iface string) error { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface, + ).CombinedOutput() + if err != nil { + msg := strings.TrimSpace(string(out)) + if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) { + return nil + } + return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg) + } + return nil +} + +func dbusErrContains(err error, code string) bool { + if err == nil { + return false + } + var de dbus.Error + if errors.As(err, &de) { + for _, b := range de.Body { + if s, ok := b.(string); ok && strings.Contains(s, code) { + return true + } + } + } + return strings.Contains(err.Error(), code) +} diff --git a/client/firewall/firewalld/firewalld_linux_test.go b/client/firewall/firewalld/firewalld_linux_test.go new file mode 100644 index 000000000..d812745fc --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux_test.go @@ -0,0 +1,49 @@ +//go:build linux + +package firewalld + +import ( + "errors" + "testing" + + "github.com/godbus/dbus/v5" +) + +func TestDBusErrContains(t *testing.T) { + tests := []struct { + name string + err error + code string + want bool + }{ + {"nil error", nil, errZoneAlreadySet, false}, + {"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true}, + {"plain error miss", errors.New("something else"), errZoneAlreadySet, false}, + { + "dbus.Error body match", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}}, + errZoneAlreadySet, + true, + }, + { + "dbus.Error body miss", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}}, + errAlreadyEnabled, + false, + }, + { + "dbus.Error non-string body falls back to Error()", + dbus.Error{Name: "x", Body: []any{123}}, + "x", + true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := dbusErrContains(tc.err, tc.code) + if got != tc.want { + t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want) + } + }) + } +} diff --git a/client/firewall/firewalld/firewalld_other.go b/client/firewall/firewalld/firewalld_other.go new file mode 100644 index 000000000..cfa28221d --- /dev/null +++ b/client/firewall/firewalld/firewalld_other.go @@ -0,0 +1,25 @@ +//go:build !linux + +package firewalld + +import "context" + +// SetParentContext is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func SetParentContext(context.Context) { + // intentionally empty: firewalld is a Linux-only daemon +} + +// TrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func TrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} + +// UntrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func UntrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 4f8864d9a..696537dd8 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -124,6 +125,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Warnf("raw table not available, notrack rules will be disabled: %v", err) } + // Trust after all fatal init steps so a later failure doesn't leave the + // interface in firewalld's trusted zone without a corresponding Close. + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + // persist early to ensure cleanup of chains go func() { if err := stateManager.PersistState(context.Background()); err != nil { @@ -349,6 +356,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) } + // Appending to merr intentionally blocks DeleteState below so ShutdownState + // stays persisted and the crash-recovery path retries firewalld cleanup. + if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + // attempt to delete state only if all other operations succeeded if merr == nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil { @@ -372,6 +385,11 @@ func (m *Manager) AllowNetbird() error { merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err)) } } + + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 68b2f2b1a..fa81056dc 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -16,6 +16,7 @@ import ( "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -422,6 +423,10 @@ func (m *Manager) AllowNetbird() error { return fmt.Errorf("flush allow input netbird rules: %w", err) } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + return nil } diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 530db5e82..1a07084c7 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -19,6 +19,7 @@ import ( "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" @@ -40,6 +41,8 @@ const ( chainNameForward = "FORWARD" chainNameMangleForward = "netbird-mangle-forward" + firewalldTableName = "firewalld" + userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptInputRule = "inputaccept" @@ -137,6 +140,10 @@ func (r *router) Reset() error { merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } + if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + if err := r.removeNatPreroutingRules(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) } @@ -284,6 +291,10 @@ func (r *router) createContainers() error { log.Errorf("failed to add accept rules for the forward chain: %s", err) } + if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to refresh rules: %s", err) } @@ -1395,6 +1406,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } + // Skip firewalld-owned chains. Firewalld creates its chains with the + // NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM. + // We delegate acceptance to firewalld by trusting the interface instead. + if chain.Table.Name == firewalldTableName { + return false + } + // Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat) if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) { return false diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 6a6533344..b120cdf12 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -3,6 +3,9 @@ package uspfilter import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/firewalld" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { if m.nativeFirewall != nil { return m.nativeFirewall.Close(stateManager) } + if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to untrust interface in firewalld: %v", err) + } return nil } @@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error { if m.nativeFirewall != nil { return m.nativeFirewall.AllowNetbird() } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } return nil } diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go index 7296953db..9c06eb3f7 100644 --- a/client/firewall/uspfilter/common/iface.go +++ b/client/firewall/uspfilter/common/iface.go @@ -9,6 +9,7 @@ import ( // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { + Name() string SetFilter(device.PacketFilter) error Address() wgaddr.Address GetWGDevice() *wgdevice.Device diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 713c9656f..f19c4bb56 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger()) var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() type IFaceMock struct { + NameFunc func() string SetFilterFunc func(device.PacketFilter) error AddressFunc func() wgaddr.Address GetWGDeviceFunc func() *wgdevice.Device GetDeviceFunc func() *device.FilteredDevice } +func (i *IFaceMock) Name() string { + if i.NameFunc == nil { + return "wgtest" + } + return i.NameFunc() +} + func (i *IFaceMock) GetWGDevice() *wgdevice.Device { if i.GetWGDeviceFunc == nil { return nil diff --git a/client/iface/bind/ice_bind_test.go b/client/iface/bind/ice_bind_test.go index 1fdd955c9..f49e68508 100644 --- a/client/iface/bind/ice_bind_test.go +++ b/client/iface/bind/ice_bind_test.go @@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) { ipv6Count++ } - assert.Equal(t, packetsPerFamily, ipv4Count) - assert.Equal(t, packetsPerFamily, ipv6Count) + // Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The + // routing-correctness checks above are the real assertions; the counts + // are a sanity bound to catch a totally silent path. + minDelivered := packetsPerFamily * 80 / 100 + assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold") + assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold") } func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { diff --git a/client/internal/debug/upload_test.go b/client/internal/debug/upload_test.go index e833c196d..f224b8d3f 100644 --- a/client/internal/debug/upload_test.go +++ b/client/internal/debug/upload_test.go @@ -3,10 +3,12 @@ package debug import ( "context" "errors" + "net" "net/http" "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/require" @@ -19,8 +21,10 @@ func TestUpload(t *testing.T) { t.Skip("Skipping upload test on docker ci") } testDir := t.TempDir() - testURL := "http://localhost:8080" + addr := reserveLoopbackPort(t) + testURL := "http://" + addr t.Setenv("SERVER_URL", testURL) + t.Setenv("SERVER_ADDRESS", addr) t.Setenv("STORE_DIR", testDir) srv := server.NewServer() go func() { @@ -33,6 +37,7 @@ func TestUpload(t *testing.T) { t.Errorf("Failed to stop server: %v", err) } }) + waitForServer(t, addr) file := filepath.Join(t.TempDir(), "tmpfile") fileContent := []byte("test file content") @@ -47,3 +52,30 @@ func TestUpload(t *testing.T) { require.NoError(t, err) require.Equal(t, fileContent, createdFileContent) } + +// reserveLoopbackPort binds an ephemeral port on loopback to learn a free +// address, then releases it so the server under test can rebind. The close/ +// rebind window is racy in theory; on loopback with a kernel-assigned port +// it's essentially never contended in practice. +func reserveLoopbackPort(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + return addr +} + +func waitForServer(t *testing.T, addr string) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err == nil { + _ = c.Close() + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("server did not start listening on %s in time", addr) +} diff --git a/client/internal/dns/file_parser_unix.go b/client/internal/dns/file_parser_unix.go index 8dacb4e51..50ba74c0c 100644 --- a/client/internal/dns/file_parser_unix.go +++ b/client/internal/dns/file_parser_unix.go @@ -13,6 +13,7 @@ import ( const ( defaultResolvConfPath = "/etc/resolv.conf" + nsswitchConfPath = "/etc/nsswitch.conf" ) type resolvConf struct { diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 6fbdedc59..57e7722d4 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,7 +1,10 @@ package dns import ( + "context" "fmt" + "math" + "net" "slices" "strconv" "strings" @@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() { } func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + c.dispatch(w, r, math.MaxInt) +} + +// dispatch routes a DNS request through the chain, skipping handlers with +// priority > maxPriority. Shared by ServeDNS and ResolveInternal. +func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) { if len(r.Question) == 0 { return } @@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // Try handlers in priority order for _, entry := range handlers { + if entry.Priority > maxPriority { + continue + } if !c.isHandlerMatch(qname, entry) { continue } @@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q cw.response.Len(), meta, time.Since(startTime)) } +// ResolveInternal runs an in-process DNS query against the chain, skipping any +// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt +// cache refresher) that must bypass themselves to avoid loops. Honors ctx +// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own +// (bounded by the invoked handler's internal timeout). +func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) { + if len(r.Question) == 0 { + return nil, fmt.Errorf("empty question") + } + + base := &internalResponseWriter{} + done := make(chan struct{}) + go func() { + c.dispatch(base, r, maxPriority) + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + // Prefer a completed response if dispatch finished concurrently with cancellation. + select { + case <-done: + default: + return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) + } + } + + if base.response == nil || base.response.Rcode == dns.RcodeRefused { + return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d", + strings.ToLower(r.Question[0].Name), maxPriority) + } + return base.response, nil +} + +// HasRootHandlerAtOrBelow reports whether any "." handler is registered at +// priority ≤ maxPriority. +func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, h := range c.handlers { + if h.Pattern == "." && h.Priority <= maxPriority { + return true + } + } + return false +} + func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { switch { case entry.Pattern == ".": @@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { } } } + +// internalResponseWriter captures a dns.Msg for in-process chain queries. +type internalResponseWriter struct { + response *dns.Msg +} + +func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil } +func (w *internalResponseWriter) LocalAddr() net.Addr { return nil } +func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil } + +// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg +// still surface their answer to ResolveInternal. +func (w *internalResponseWriter) Write(p []byte) (int, error) { + msg := new(dns.Msg) + if err := msg.Unpack(p); err != nil { + return 0, err + } + w.response = msg + return len(p), nil +} + +func (w *internalResponseWriter) Close() error { return nil } +func (w *internalResponseWriter) TsigStatus() error { return nil } + +// TsigTimersOnly is part of dns.ResponseWriter. +func (w *internalResponseWriter) TsigTimersOnly(bool) { + // no-op: in-process queries carry no TSIG state. +} + +// Hijack is part of dns.ResponseWriter. +func (w *internalResponseWriter) Hijack() { + // no-op: in-process queries have no underlying connection to hand off. +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index fa9525069..034a760dc 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,11 +1,15 @@ package dns_test import ( + "context" + "net" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/test" @@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { }) } } + +// answeringHandler writes a fixed A record to ack the query. Used to verify +// which handler ResolveInternal dispatches to. +type answeringHandler struct { + name string + ip string +} + +func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + _ = w.WriteMsg(resp) +} + +func (h *answeringHandler) String() string { return h.name } + +func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) { + chain := nbdns.NewHandlerChain() + + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + low := &answeringHandler{name: "low", ip: "10.0.0.2"} + + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + chain.AddHandler("example.com.", low, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 1, len(resp.Answer)) + a, ok := resp.Answer[0].(*dns.A) + assert.True(t, ok) + assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream") +} + +func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) { + chain := nbdns.NewHandlerChain() + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + _, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.Error(t, err, "no handler at or below maxPriority should error") +} + +// rawWriteHandler packs a response and calls ResponseWriter.Write directly +// (instead of WriteMsg), exercising the internalResponseWriter.Write path. +type rawWriteHandler struct { + ip string +} + +func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + packed, err := resp.Pack() + if err != nil { + return + } + _, _ = w.Write(packed) +} + +func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) { + chain := nbdns.NewHandlerChain() + chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Answer, 1) + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer") +} + +func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) { + chain := nbdns.NewHandlerChain() + _, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream) + assert.Error(t, err) +} + +// hangingHandler blocks indefinitely until closed, simulating a wedged upstream. +type hangingHandler struct { + block chan struct{} +} + +func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + <-h.block + resp := &dns.Msg{} + resp.SetReply(r) + _ = w.WriteMsg(resp) +} + +func (h *hangingHandler) String() string { return "hangingHandler" } + +func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &hangingHandler{block: make(chan struct{})} + defer close(h.block) + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline") +} + +func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &answeringHandler{name: "h", ip: "10.0.0.1"} + + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain") + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count") + + chain.AddHandler(".", h, nbdns.PriorityMgmtCache) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded") + + chain.AddHandler(".", h, nbdns.PriorityDefault) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included") + + chain.RemoveHandler(".", nbdns.PriorityDefault) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) + + // Primary nsgroup case: root handler lands at PriorityUpstream. + chain.AddHandler(".", h, nbdns.PriorityUpstream) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included") + chain.RemoveHandler(".", nbdns.PriorityUpstream) + + // Fallback case: original /etc/resolv.conf entries land at PriorityFallback. + chain.AddHandler(".", h, nbdns.PriorityFallback) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included") + chain.RemoveHandler(".", nbdns.PriorityFallback) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) +} diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 422fed4e5..d7301d725 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -46,12 +46,12 @@ type restoreHostManager interface { } func newHostManager(wgInterface string) (hostManager, error) { - osManager, err := getOSDNSManagerType() + osManager, reason, err := getOSDNSManagerType() if err != nil { return nil, fmt.Errorf("get os dns manager type: %w", err) } - log.Infof("System DNS manager discovered: %s", osManager) + log.Infof("System DNS manager discovered: %s (%s)", osManager, reason) mgr, err := newHostManagerFromType(wgInterface, osManager) // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value if err != nil { @@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor } } -func getOSDNSManagerType() (osManagerType, error) { +func getOSDNSManagerType() (osManagerType, string, error) { + resolved := isSystemdResolvedRunning() + nss := isLibnssResolveUsed() + stub := checkStub() + + // Prefer systemd-resolved whenever it owns libc resolution, regardless of + // who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups + // that go through nss-resolve, and in foreign mode they can loop back + // through resolved as an upstream. + if resolved && (nss || stub) { + return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil + } + + mgr, reason, rejected, err := scanResolvConfHeader() + if err != nil { + return 0, "", err + } + if reason != "" { + return mgr, reason, nil + } + + fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub) + if len(rejected) > 0 { + fallback += "; rejected: " + strings.Join(rejected, ", ") + } + return fileManager, fallback, nil +} + +// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the +// matching manager. If reason is empty the caller should pick file mode and +// use rejected for diagnostics. +func scanResolvConfHeader() (osManagerType, string, []string, error) { file, err := os.Open(defaultResolvConfPath) if err != nil { - return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) + return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) } defer func() { - if err := file.Close(); err != nil { - log.Errorf("close file %s: %s", defaultResolvConfPath, err) + if cerr := file.Close(); cerr != nil { + log.Errorf("close file %s: %s", defaultResolvConfPath, cerr) } }() + var rejected []string scanner := bufio.NewScanner(file) for scanner.Scan() { text := scanner.Text() @@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) { continue } if text[0] != '#' { - return fileManager, nil + break } - if strings.Contains(text, fileGeneratedResolvConfContentHeader) { - return netbirdManager, nil - } - if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { - return networkManager, nil - } - if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() { - if checkStub() { - return systemdManager, nil - } else { - return fileManager, nil - } - } - if strings.Contains(text, "resolvconf") { - if isSystemdResolveConfMode() { - return systemdManager, nil - } - - return resolvConfManager, nil + if mgr, reason, rej := matchResolvConfHeader(text); reason != "" { + return mgr, reason, nil, nil + } else if rej != "" { + rejected = append(rejected, rej) } } if err := scanner.Err(); err != nil && err != io.EOF { - return 0, fmt.Errorf("scan: %w", err) + return 0, "", nil, fmt.Errorf("scan: %w", err) } - - return fileManager, nil + return 0, "", rejected, nil } -// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager. +// matchResolvConfHeader inspects a single comment line. Returns either a +// definitive (manager, reason) or a non-empty rejected diagnostic. +func matchResolvConfHeader(text string) (osManagerType, string, string) { + if strings.Contains(text, fileGeneratedResolvConfContentHeader) { + return netbirdManager, "netbird-managed resolv.conf header detected", "" + } + if strings.Contains(text, "NetworkManager") { + if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { + return networkManager, "NetworkManager header + supported version on dbus", "" + } + return 0, "", "NetworkManager header (no dbus or unsupported version)" + } + if strings.Contains(text, "resolvconf") { + if isSystemdResolveConfMode() { + return systemdManager, "resolvconf header in systemd-resolved compatibility mode", "" + } + return resolvConfManager, "resolvconf header detected", "" + } + return 0, "", "" +} + +// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed +// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping +// into file mode while resolved is active. func checkStub() bool { rConf, err := parseDefaultResolvConf() if err != nil { - log.Warnf("failed to parse resolv conf: %s", err) + log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err) return true } @@ -139,3 +178,36 @@ func checkStub() bool { return false } + +// isLibnssResolveUsed reports whether nss-resolve is listed before dns on +// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are +// delegated to systemd-resolved regardless of /etc/resolv.conf. +func isLibnssResolveUsed() bool { + bs, err := os.ReadFile(nsswitchConfPath) + if err != nil { + log.Debugf("read %s: %v", nsswitchConfPath, err) + return false + } + return parseNsswitchResolveAhead(bs) +} + +func parseNsswitchResolveAhead(data []byte) bool { + for _, line := range strings.Split(string(data), "\n") { + if i := strings.IndexByte(line, '#'); i >= 0 { + line = line[:i] + } + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "hosts:" { + continue + } + for _, module := range fields[1:] { + switch module { + case "dns": + return false + case "resolve": + return true + } + } + } + return false +} diff --git a/client/internal/dns/host_unix_test.go b/client/internal/dns/host_unix_test.go new file mode 100644 index 000000000..e936281d3 --- /dev/null +++ b/client/internal/dns/host_unix_test.go @@ -0,0 +1,76 @@ +//go:build (linux && !android) || freebsd + +package dns + +import "testing" + +func TestParseNsswitchResolveAhead(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + { + name: "resolve before dns with action token", + in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n", + want: true, + }, + { + name: "dns before resolve", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n", + want: false, + }, + { + name: "debian default with only dns", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n", + want: false, + }, + { + name: "neither resolve nor dns", + in: "hosts: files myhostname\n", + want: false, + }, + { + name: "no hosts line", + in: "passwd: files systemd\ngroup: files systemd\n", + want: false, + }, + { + name: "empty", + in: "", + want: false, + }, + { + name: "comments and blank lines ignored", + in: "# comment\n\n# another\nhosts: resolve dns\n", + want: true, + }, + { + name: "trailing inline comment", + in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n", + want: true, + }, + { + name: "hosts token must be the first field", + in: " hosts: resolve dns\n", + want: true, + }, + { + name: "other db line mentioning resolve is ignored", + in: "networks: resolve\nhosts: dns\n", + want: false, + }, + { + name: "only resolve, no dns", + in: "hosts: files resolve\n", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want { + t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 314af51d9..988e427fb 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -2,40 +2,83 @@ package mgmt import ( "context" + "errors" "fmt" "net" - "net/netip" "net/url" + "os" + "slices" "strings" "sync" + "sync/atomic" "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/shared/management/domain" ) -const dnsTimeout = 5 * time.Second +const ( + dnsTimeout = 5 * time.Second + defaultTTL = 300 * time.Second + refreshBackoff = 30 * time.Second -// Resolver caches critical NetBird infrastructure domains + // envMgmtCacheTTL overrides defaultTTL for integration/dev testing. + envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" +) + +// ChainResolver lets the cache refresh stale entries through the DNS handler +// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the +// system resolver. +type ChainResolver interface { + ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) + HasRootHandlerAtOrBelow(maxPriority int) bool +} + +// cachedRecord holds DNS records plus timestamps used for TTL refresh. +// records and cachedAt are set at construction and treated as immutable; +// lastFailedRefresh and consecFailures are mutable and must be accessed under +// Resolver.mutex. +type cachedRecord struct { + records []dns.RR + cachedAt time.Time + lastFailedRefresh time.Time + consecFailures int +} + +// Resolver caches critical NetBird infrastructure domains. +// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex. type Resolver struct { - records map[dns.Question][]dns.RR + records map[dns.Question]*cachedRecord mgmtDomain *domain.Domain serverDomains *dnsconfig.ServerDomains mutex sync.RWMutex -} -type ipsResponse struct { - ips []netip.Addr - err error + chain ChainResolver + chainMaxPriority int + refreshGroup singleflight.Group + + // refreshing tracks questions whose refresh is running via the OS + // fallback path. A ServeDNS hit for a question in this map indicates + // the OS resolver routed the recursive query back to us (loop). Only + // the OS path arms this so chain-path refreshes don't produce false + // positives. The atomic bool is CAS-flipped once per refresh to + // throttle the warning log. + refreshing map[dns.Question]*atomic.Bool + + cacheTTL time.Duration } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - records: make(map[dns.Question][]dns.RR), + records: make(map[dns.Question]*cachedRecord), + refreshing: make(map[dns.Question]*atomic.Bool), + cacheTTL: resolveCacheTTL(), } } @@ -44,7 +87,19 @@ func (m *Resolver) String() string { return "MgmtCacheResolver" } -// ServeDNS implements dns.Handler interface. +// SetChainResolver wires the handler chain used to refresh stale cache entries. +// maxPriority caps which handlers may answer refresh queries (typically +// PriorityUpstream, so upstream/default/fallback handlers are consulted and +// mgmt/route/local handlers are skipped). +func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { + m.mutex.Lock() + m.chain = chain + m.chainMaxPriority = maxPriority + m.mutex.Unlock() +} + +// ServeDNS serves cached A/AAAA records. Stale entries are returned +// immediately and refreshed asynchronously (stale-while-revalidate). func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - records, found := m.records[question] + cached, found := m.records[question] + inflight := m.refreshing[question] + var shouldRefresh bool + if found { + stale := time.Since(cached.cachedAt) > m.cacheTTL + inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff + shouldRefresh = stale && !inBackoff + } m.mutex.RUnlock() if !found { @@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + if inflight != nil && inflight.CompareAndSwap(false, true) { + log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)", + question.Name) + } + + // Skip scheduling a refresh goroutine if one is already inflight for + // this question; singleflight would dedup anyway but skipping avoids + // a parked goroutine per stale hit under bursty load. + if shouldRefresh && inflight == nil { + m.scheduleRefresh(question, cached) + } + resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - - resp.Answer = append(resp.Answer, records...) + resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt)) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { } } -// AddDomain manually adds a domain to cache by resolving it. +// AddDomain resolves a domain and stores its A/AAAA records in the cache. +// A family that resolves NODATA (nil err, zero records) evicts any stale +// entry for that qtype. func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := lookupIPWithExtraTimeout(ctx, d) - if err != nil { - return err + aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) + + if errA != nil && errAAAA != nil { + return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) } - var aRecords, aaaaRecords []dns.RR - for _, ip := range ips { - if ip.Is4() { - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: ip.AsSlice(), - } - aRecords = append(aRecords, rr) - } else if ip.Is6() { - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: ip.AsSlice(), - } - aaaaRecords = append(aaaaRecords, rr) + if len(aRecords) == 0 && len(aaaaRecords) == 0 { + if err := errors.Join(errA, errAAAA); err != nil { + return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) } + return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } + now := time.Now() m.mutex.Lock() + defer m.mutex.Unlock() - if len(aRecords) > 0 { - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - m.records[aQuestion] = aRecords - } + m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) + m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) - if len(aaaaRecords) > 0 { - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - m.records[aaaaQuestion] = aaaaRecords - } - - m.mutex.Unlock() - - log.Debugf("added domain=%s with %d A records and %d AAAA records", + log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", d.SafeString(), len(aRecords), len(aaaaRecords)) return nil } -func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { - log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) - defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) - resultChan := make(chan *ipsResponse, 1) +// applyFamilyRecords writes records, evicts on NODATA, leaves the cache +// untouched on error. Caller holds m.mutex. +func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { + q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} + switch { + case len(records) > 0: + m.records[q] = &cachedRecord{records: records, cachedAt: now} + case err == nil: + delete(m.records, q) + } +} - go func() { - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) - resultChan <- &ipsResponse{ - err: err, - ips: ips, +// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per +// unique in-flight key; bursty stale hits share its channel. expected is the +// cachedRecord pointer observed by the caller; the refresh only mutates the +// cache if that pointer is still the one stored, so a stale in-flight refresh +// can't clobber a newer entry written by AddDomain or a competing refresh. +func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) { + key := question.Name + "|" + dns.TypeToString[question.Qtype] + _ = m.refreshGroup.DoChan(key, func() (any, error) { + return nil, m.refreshQuestion(question, expected) + }) +} + +// refreshQuestion replaces the cached records on success, or marks the entry +// failed (arming the backoff) on failure. While this runs, ServeDNS can detect +// a resolver loop by spotting a query for this same question arriving on us. +// expected pins the cache entry observed at schedule time; mutations only apply +// if m.records[question] still points at it. +func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error { + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) + if err != nil { + m.markRefreshFailed(question, expected) + return fmt.Errorf("parse domain: %w", err) + } + + records, err := m.lookupRecords(ctx, d, question) + if err != nil { + fails := m.markRefreshFailed(question, expected) + logf := log.Warnf + if fails == 0 || fails > 1 { + logf = log.Debugf } - }() - - var resp *ipsResponse - - select { - case <-time.After(dnsTimeout + time.Millisecond*500): - log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) - return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) - case <-ctx.Done(): - return nil, ctx.Err() - case resp = <-resultChan: + logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", + d.SafeString(), dns.TypeToString[question.Qtype], err, fails) + return err } - if resp.err != nil { - return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. + if len(records) == 0 { + m.mutex.Lock() + if m.records[question] == expected { + delete(m.records, question) + m.mutex.Unlock() + log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.mutex.Unlock() + log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil } - return resp.ips, nil + + now := time.Now() + m.mutex.Lock() + if m.records[question] != expected { + m.mutex.Unlock() + log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.records[question] = &cachedRecord{records: records, cachedAt: now} + m.mutex.Unlock() + + log.Infof("refreshed mgmt cache domain=%s type=%s", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil +} + +func (m *Resolver) markRefreshing(question dns.Question) { + m.mutex.Lock() + m.refreshing[question] = &atomic.Bool{} + m.mutex.Unlock() +} + +func (m *Resolver) clearRefreshing(question dns.Question) { + m.mutex.Lock() + delete(m.refreshing, question) + m.mutex.Unlock() +} + +// markRefreshFailed arms the backoff and returns the new consecutive-failure +// count so callers can downgrade subsequent failure logs to debug. +func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int { + m.mutex.Lock() + defer m.mutex.Unlock() + c, ok := m.records[question] + if !ok || c != expected { + return 0 + } + c.lastFailedRefresh = time.Now() + c.consecFailures++ + return c.consecFailures +} + +// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let +// callers tell records, NODATA (nil err, no records), and failure apart. +func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) + return + } + + // TODO: drop once every supported OS registers a fallback resolver. Safe + // today: no root handler at priority ≤ PriorityUpstream means NetBird is + // not the system resolver, so net.DefaultResolver will not loop back. + aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) + return +} + +// lookupRecords resolves a single record type via chain or OS. The OS branch +// arms the loop detector for the duration of its call so that ServeDNS can +// spot the OS resolver routing the recursive query back to us. +func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) + } + + // TODO: drop once every supported OS registers a fallback resolver. + m.markRefreshing(q) + defer m.clearRefreshing(q) + + return m.osLookup(ctx, d, q.Name, q.Qtype) +} + +// lookupViaChain resolves via the handler chain and rewrites each RR to use +// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache +// target-owned records or upstream TTLs. NODATA returns (nil, nil). +func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { + msg := &dns.Msg{} + msg.SetQuestion(dnsName, qtype) + msg.RecursionDesired = true + + resp, err := chain.ResolveInternal(ctx, msg, maxPriority) + if err != nil { + return nil, fmt.Errorf("chain resolve: %w", err) + } + if resp == nil { + return nil, fmt.Errorf("chain resolve returned nil response") + } + if resp.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) + } + + ttl := uint32(m.cacheTTL.Seconds()) + owners := cnameOwners(dnsName, resp.Answer) + var filtered []dns.RR + for _, rr := range resp.Answer { + h := rr.Header() + if h.Class != dns.ClassINET || h.Rrtype != qtype { + continue + } + if !owners[strings.ToLower(dns.Fqdn(h.Name))] { + continue + } + if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { + filtered = append(filtered, cp) + } + } + return filtered, nil +} + +// osLookup resolves a single family via net.DefaultResolver using resutil, +// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA +// returns (nil, nil). +func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { + network := resutil.NetworkForQtype(qtype) + if network == "" { + return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) + } + + log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + + result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + if result.Rcode == dns.RcodeSuccess { + return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil + } + + if result.Err != nil { + return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) + } + return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) +} + +// responseTTL returns the remaining cache lifetime in seconds (rounded up), +// so downstream resolvers don't cache an answer for longer than we will. +func (m *Resolver) responseTTL(cachedAt time.Time) uint32 { + remaining := m.cacheTTL - time.Since(cachedAt) + if remaining <= 0 { + return 0 + } + return uint32((remaining + time.Second - 1) / time.Second) } // PopulateFromConfig extracts and caches domains from the client configuration. @@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - delete(m.records, aQuestion) - - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - delete(m.records, aaaaQuestion) + qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} + qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} + delete(m.records, qA) + delete(m.records, qAAAA) + delete(m.refreshing, qA) + delete(m.refreshing, qAAAA) log.Debugf("removed domain=%s from cache", d.SafeString()) return nil @@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } + +// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non +// A/AAAA records return nil. +func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + return &cp + case *dns.AAAA: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + return &cp + } + return nil +} + +// cloneRecordsWithTTL clones A/AAAA records preserving their owner and +// stamping ttl so the response shares no memory with the cached slice. +func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { + out := make([]dns.RR, 0, len(records)) + for _, rr := range records { + if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil { + out = append(out, cp) + } + } + return out +} + +// cnameOwners returns dnsName plus every target reachable by following CNAMEs +// in answer, iterating until fixed point so out-of-order chains resolve. +func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { + owners := map[string]bool{dnsName: true} + for { + added := false + for _, rr := range answer { + cname, ok := rr.(*dns.CNAME) + if !ok { + continue + } + name := strings.ToLower(dns.Fqdn(cname.Hdr.Name)) + if !owners[name] { + continue + } + target := strings.ToLower(dns.Fqdn(cname.Target)) + if !owners[target] { + owners[target] = true + added = true + } + } + if !added { + return owners + } + } +} + +// resolveCacheTTL reads the cache TTL override env var; invalid or empty +// values fall back to defaultTTL. Called once per Resolver from NewResolver. +func resolveCacheTTL() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return defaultTTL +} diff --git a/client/internal/dns/mgmt/mgmt_refresh_test.go b/client/internal/dns/mgmt/mgmt_refresh_test.go new file mode 100644 index 000000000..9faa5a0b8 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_refresh_test.go @@ -0,0 +1,408 @@ +package mgmt + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +type fakeChain struct { + mu sync.Mutex + calls map[string]int + answers map[string][]dns.RR + err error + hasRoot bool + onLookup func() +} + +func newFakeChain() *fakeChain { + return &fakeChain{ + calls: map[string]int{}, + answers: map[string][]dns.RR{}, + hasRoot: true, + } +} + +func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.hasRoot +} + +func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) { + f.mu.Lock() + q := msg.Question[0] + key := q.Name + "|" + dns.TypeToString[q.Qtype] + f.calls[key]++ + answers := f.answers[key] + err := f.err + onLookup := f.onLookup + f.mu.Unlock() + + if onLookup != nil { + onLookup() + } + if err != nil { + return nil, err + } + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Answer = answers + return resp, nil +} + +func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) { + f.mu.Lock() + defer f.mu.Unlock() + key := name + "|" + dns.TypeToString[qtype] + hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60} + switch qtype { + case dns.TypeA: + f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}} + case dns.TypeAAAA: + f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}} + } +} + +func (f *fakeChain) callCount(name string, qtype uint16) int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls[name+"|"+dns.TypeToString[qtype]] +} + +// waitFor polls the predicate until it returns true or the deadline passes. +func waitFor(t *testing.T, d time.Duration, fn func() bool) { + t.Helper() + deadline := time.Now().Add(d) + for time.Now().Before(deadline) { + if fn() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("condition not met within %s", d) +} + +func queryA(t *testing.T, r *Resolver, name string) *dns.Msg { + t.Helper() + msg := new(dns.Msg) + msg.SetQuestion(name, dns.TypeA) + w := &test.MockResponseWriter{} + r.ServeDNS(w, msg) + return w.GetLastResponse() +} + +func firstA(t *testing.T, resp *dns.Msg) string { + t.Helper() + require.NotNil(t, resp) + require.Greater(t, len(resp.Answer), 0, "expected at least one answer") + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok, "expected A record") + return a.A.String() +} + +func TestResolver_CacheTTLGatesRefresh(t *testing.T) { + // Same cached entry age, different cacheTTL values: the shorter TTL must + // trigger a background refresh, the longer one must not. Proves that the + // per-Resolver cacheTTL field actually drives the stale decision. + cachedAt := time.Now().Add(-100 * time.Millisecond) + + newRec := func() *cachedRecord { + return &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: cachedAt, + } + } + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + + t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = 10 * time.Millisecond + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount(q.Name, dns.TypeA) >= 1 + }) + }) + + t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = time.Hour + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh") + }) +} + +func TestResolver_ServeFresh_NoRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), // fresh + } + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(20 * time.Millisecond) + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh") +} + +func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), // stale + } + + // First query: serves stale immediately. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1 + }) + + // Next query should now return the refreshed IP. + waitFor(t, time.Second, func() bool { + resp := queryA(t, r, "mgmt.example.com.") + return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2" + }) +} + +func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + + var inflight atomic.Int32 + var maxInflight atomic.Int32 + chain.onLookup = func() { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + prev := maxInflight.Load() + if cur <= prev || maxInflight.CompareAndSwap(prev, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide + } + + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + queryA(t, r, "mgmt.example.com.") + }() + } + wg.Wait() + + waitFor(t, 2*time.Second, func() bool { + return inflight.Load() == 0 + }) + + calls := chain.callCount("mgmt.example.com.", dns.TypeA) + assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls) + assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently") +} + +func TestResolver_RefreshFailureArmsBackoff(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.err = errors.New("boom") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + // First stale hit triggers a refresh attempt that fails. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) == 1 + }) + waitFor(t, time.Second, func() bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + c, ok := r.records[q] + return ok && !c.lastFailedRefresh.IsZero() + }) + + // Subsequent stale hits within backoff window should not schedule more refreshes. + for i := 0; i < 10; i++ { + queryA(t, r, "mgmt.example.com.") + } + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes") +} + +func TestResolver_NoRootHandler_SkipsChain(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.hasRoot = false + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + // With hasRoot=false the chain must not be consulted. Use a short + // deadline so the OS fallback returns quickly without waiting on a + // real network call in CI. + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + _, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.") + + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), + "chain must not be used when no root handler is registered at the bound priority") +} + +func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) { + // ServeDNS being invoked for a question while a refresh for that question + // is inflight indicates a resolver loop (OS resolver sent the recursive + // query back to us). The inflightRefresh.loopLoggedOnce flag must be set. + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + // Simulate an inflight refresh. + r.markRefreshing(q) + defer r.clearRefreshing(q) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries") + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + require.NotNil(t, inflight) + assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed") +} + +func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + r.markRefreshing(q) + defer r.clearRefreshing(q) + + // Multiple ServeDNS calls during the same refresh must not re-set the flag + // (CompareAndSwap from false -> true returns true only on the first call). + for range 5 { + queryA(t, r, "mgmt.example.com.") + } + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + assert.True(t, inflight.Load()) +} + +func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + queryA(t, r, "mgmt.example.com.") + + r.mutex.RLock() + _, ok := r.refreshing[q] + r.mutex.RUnlock() + assert.False(t, ok, "no refresh inflight means no loop tracking") +} + +func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2") + r.SetChainResolver(chain, 50) + + require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com"))) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.2", firstA(t, resp)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA)) +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 9e8a746f3..276cbba0a 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -6,6 +6,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) { assert.False(t, resolver.MatchSubdomains()) } +func TestResolveCacheTTL(t *testing.T) { + tests := []struct { + name string + value string + want time.Duration + }{ + {"unset falls back to default", "", defaultTTL}, + {"valid duration", "45s", 45 * time.Second}, + {"valid minutes", "2m", 2 * time.Minute}, + {"malformed falls back to default", "not-a-duration", defaultTTL}, + {"zero falls back to default", "0s", defaultTTL}, + {"negative falls back to default", "-5s", defaultTTL}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envMgmtCacheTTL, tc.value) + got := resolveCacheTTL() + assert.Equal(t, tc.want, got, "parsed TTL should match") + }) + } +} + +func TestNewResolver_CacheTTLFromEnv(t *testing.T) { + t.Setenv(envMgmtCacheTTL, "7s") + r := NewResolver() + assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env") +} + +func TestResolver_ResponseTTL(t *testing.T) { + now := time.Now() + tests := []struct { + name string + cacheTTL time.Duration + cachedAt time.Time + wantMin uint32 + wantMax uint32 + }{ + {"fresh entry returns full TTL", 60 * time.Second, now, 59, 60}, + {"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31}, + {"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0}, + {"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := &Resolver{cacheTTL: tc.cacheTTL} + got := r.responseTTL(tc.cachedAt) + assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin") + assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax") + }) + } +} + func TestResolver_ExtractDomainFromURL(t *testing.T) { tests := []struct { name string diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f7865047b..d4f54dec5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -212,6 +212,7 @@ func newDefaultServer( ctx, stop := context.WithCancel(ctx) mgmtCacheResolver := mgmt.NewResolver() + mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream) defaultServer := &DefaultServer{ ctx: ctx, diff --git a/client/internal/engine.go b/client/internal/engine.go index 375b96f04..78cd235cf 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -26,6 +26,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" @@ -574,7 +575,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.connMgr.Start(e.ctx) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) - e.srWatcher.Start() + e.srWatcher.Start(peer.IsForceRelayed()) e.receiveSignalEvents() e.receiveManagementEvents() @@ -608,6 +609,8 @@ func (e *Engine) createFirewall() error { return nil } + firewalld.SetParentContext(e.ctx) + var err error e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8d1585b3f..1e416bfe7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) - relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) - if err != nil { - return err + forceRelay := IsForceRelayed() + if !forceRelay { + relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() + workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) + if err != nil { + return err + } + conn.workerICE = workerICE } - conn.workerICE = workerICE conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) - if !isForceRelayed() { + if !forceRelay { conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } @@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgWatcherCancel() } conn.workerRelay.CloseConn() - conn.workerICE.Close() + if conn.workerICE != nil { + conn.workerICE.Close() + } if conn.wgProxyRelay != nil { err := conn.wgProxyRelay.CloseConn() @@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) { // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { conn.dumpState.RemoteCandidate() - conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + if conn.workerICE != nil { + conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + } } // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established @@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusConnecting } -func (conn *Conn) isConnectedOnAllWay() (connected bool) { - // would be better to protect this with a mutex, but it could cause deadlock with Close function - +// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports. +// +// The result is a tri-state: +// - ConnStatusConnected: all available transports are up +// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting +// - ConnStatusDisconnected: no working transport +func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) { defer func() { - if !connected { + if status == guard.ConnStatusDisconnected { conn.logTraceConnState() } }() - // For JS platform: only relay connection is supported - if runtime.GOOS == "js" { - return conn.statusRelay.Get() == worker.StatusConnected + iceWorkerCreated := conn.workerICE != nil + + var iceInProgress bool + if iceWorkerCreated { + iceInProgress = conn.workerICE.InProgress() } - // For non-JS platforms: check ICE connection status - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { - return false - } - - // If relay is supported with peer, it must also be connected - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == worker.StatusDisconnected { - return false - } - } - - return true + return evalConnStatus(connStatusInputs{ + forceRelay: IsForceRelayed(), + peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(), + relayConnected: conn.statusRelay.Get() == worker.StatusConnected, + remoteSupportsICE: conn.handshaker.RemoteICESupported(), + iceWorkerCreated: iceWorkerCreated, + iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected, + iceInProgress: iceInProgress, + }) } func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { @@ -926,3 +935,43 @@ func isController(config ConnConfig) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } + +func evalConnStatus(in connStatusInputs) guard.ConnStatus { + // "Relay up and needed" — the peer uses relay and the transport is connected. + relayUsedAndUp := in.peerUsesRelay && in.relayConnected + + // Force-relay mode: ICE never runs. Relay is the only transport and must be up. + if in.forceRelay { + return boolToConnStatus(relayUsedAndUp) + } + + // Remote peer doesn't support ICE, or we haven't created the worker yet: + // relay is the only possible transport. + if !in.remoteSupportsICE || !in.iceWorkerCreated { + return boolToConnStatus(relayUsedAndUp) + } + + // ICE counts as "up" when the status is anything other than Disconnected, OR + // when a negotiation is currently in progress (so we don't spam offers while one is in flight). + iceUp := in.iceStatusConnecting || in.iceInProgress + + // Relay side is acceptable if the peer doesn't rely on relay, or relay is connected. + relayOK := !in.peerUsesRelay || in.relayConnected + + switch { + case iceUp && relayOK: + return guard.ConnStatusConnected + case relayUsedAndUp: + // Relay is up but ICE is down — partially connected. + return guard.ConnStatusPartiallyConnected + default: + return guard.ConnStatusDisconnected + } +} + +func boolToConnStatus(connected bool) guard.ConnStatus { + if connected { + return guard.ConnStatusConnected + } + return guard.ConnStatusDisconnected +} diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go index 73acc5ef5..b43e245f3 100644 --- a/client/internal/peer/conn_status.go +++ b/client/internal/peer/conn_status.go @@ -13,6 +13,20 @@ const ( StatusConnected ) +// connStatusInputs is the primitive-valued snapshot of the state that drives the +// tri-state connection classification. Extracted so the decision logic can be unit-tested +// without constructing full Worker/Handshaker objects. +type connStatusInputs struct { + forceRelay bool // NB_FORCE_RELAY or JS/WASM + peerUsesRelay bool // remote peer advertises relay support AND local has relay + relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay) + remoteSupportsICE bool // remote peer sent ICE credentials + iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode) + iceStatusConnecting bool // statusICE is anything other than Disconnected + iceInProgress bool // a negotiation is currently in flight +} + + // ConnStatus describe the status of a peer's connection type ConnStatus int32 diff --git a/client/internal/peer/conn_status_eval_test.go b/client/internal/peer/conn_status_eval_test.go new file mode 100644 index 000000000..66393cafe --- /dev/null +++ b/client/internal/peer/conn_status_eval_test.go @@ -0,0 +1,201 @@ +package peer + +import ( + "testing" + + "github.com/netbirdio/netbird/client/internal/peer/guard" +) + +func TestEvalConnStatus_ForceRelay(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "force relay, peer uses relay, relay up", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "force relay, peer uses relay, relay down", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: false, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "force relay, peer does NOT use relay - disconnected forever", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: false, + relayConnected: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_ICEUnavailable(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "remote does not support ICE, peer uses relay, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer uses relay, relay down", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE worker not yet created, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: true, + iceWorkerCreated: false, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer does not use relay", + in: connStatusInputs{ + peerUsesRelay: false, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_FullyAvailable(t *testing.T) { + base := connStatusInputs{ + remoteSupportsICE: true, + iceWorkerCreated: true, + } + + tests := []struct { + name string + mutator func(*connStatusInputs) + want guard.ConnStatus + }{ + { + name: "ICE connected, relay connected, peer uses relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE connected, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE InProgress only, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.iceStatusConnecting = false + in.iceInProgress = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE down, relay up, peer uses relay -> partial", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusPartiallyConnected, + }, + { + name: "ICE down, peer does NOT use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = false + in.iceStatusConnecting = true + }, + // relayOK = false (peer uses relay but it's down), iceUp = true + // first switch arm fails (relayOK false), relayUsedAndUp = false (relay down), + // falls into default: Disconnected. + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE down, relay up but peer does not use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = true // not actually used since peer doesn't rely on it + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + in := base + tc.mutator(&in) + if got := evalConnStatus(in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in) + } + }) + } +} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 7f500c410..b4ba9ad7b 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -10,7 +10,7 @@ const ( EnvKeyNBForceRelay = "NB_FORCE_RELAY" ) -func isForceRelayed() bool { +func IsForceRelayed() bool { if runtime.GOOS == "js" { return true } diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go index d93403730..2e5efbcc5 100644 --- a/client/internal/peer/guard/guard.go +++ b/client/internal/peer/guard/guard.go @@ -8,7 +8,19 @@ import ( log "github.com/sirupsen/logrus" ) -type isConnectedFunc func() bool +// ConnStatus represents the connection state as seen by the guard. +type ConnStatus int + +const ( + // ConnStatusDisconnected means neither ICE nor Relay is connected. + ConnStatusDisconnected ConnStatus = iota + // ConnStatusPartiallyConnected means Relay is connected but ICE is not. + ConnStatusPartiallyConnected + // ConnStatusConnected means all required connections are established. + ConnStatusConnected +) + +type connStatusFunc func() ConnStatus // Guard is responsible for the reconnection logic. // It will trigger to send an offer to the peer then has connection issues. @@ -20,14 +32,14 @@ type isConnectedFunc func() bool // - ICE candidate changes type Guard struct { log *log.Entry - isConnectedOnAllWay isConnectedFunc + isConnectedOnAllWay connStatusFunc timeout time.Duration srWatcher *SRWatcher relayedConnDisconnected chan struct{} iCEConnDisconnected chan struct{} } -func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { +func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { return &Guard{ log: log, isConnectedOnAllWay: isConnectedFn, @@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() { } } -// reconnectLoopWithRetry periodically check the connection status. -// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported +// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity. +// +// Behavior depends on the connection state reported by isConnectedOnAllWay: +// - Connected: no action, the peer is fully reachable. +// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling +// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all. +// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches +// to one attempt per hour. This limits signaling traffic when relay already provides connectivity. +// +// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry +// counter and backoff ticker, giving ICE a fresh chance after network conditions change. func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { srReconnectedChan := g.srWatcher.NewListener() defer g.srWatcher.RemoveListener(srReconnectedChan) @@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { tickerChannel := ticker.C + iceState := &iceRetryState{log: g.log} + defer iceState.reset() + for { select { - case t := <-tickerChannel: - if t.IsZero() { - g.log.Infof("retry timed out, stop periodic offer sending") - // after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop - tickerChannel = make(<-chan time.Time) - continue + case <-tickerChannel: + switch g.isConnectedOnAllWay() { + case ConnStatusConnected: + // all good, nothing to do + case ConnStatusDisconnected: + callback() + case ConnStatusPartiallyConnected: + if iceState.shouldRetry() { + callback() + } else { + iceState.enterHourlyMode() + ticker.Stop() + tickerChannel = iceState.hourlyC() + } } - if !g.isConnectedOnAllWay() { - callback() - } case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-g.iCEConnDisconnected: g.log.Debugf("ICE connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-srReconnectedChan: g.log.Debugf("has network changes, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-ctx.Done(): g.log.Debugf("context is done, stop reconnect loop") @@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker { return backoff.NewTicker(bo) } -func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, RandomizationFactor: 0.1, diff --git a/client/internal/peer/guard/ice_retry_state.go b/client/internal/peer/guard/ice_retry_state.go new file mode 100644 index 000000000..01dc1bf2d --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state.go @@ -0,0 +1,61 @@ +package guard + +import ( + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // maxICERetries is the maximum number of ICE offer attempts when relay is connected + maxICERetries = 3 + // iceRetryInterval is the periodic retry interval after ICE retries are exhausted + iceRetryInterval = 1 * time.Hour +) + +// iceRetryState tracks the limited ICE retry attempts when relay is already connected. +// After maxICERetries attempts it switches to a periodic hourly retry. +type iceRetryState struct { + log *log.Entry + retries int + hourly *time.Ticker +} + +func (s *iceRetryState) reset() { + s.retries = 0 + if s.hourly != nil { + s.hourly.Stop() + s.hourly = nil + } +} + +// shouldRetry reports whether the caller should send another ICE offer on this tick. +// Returns false when the per-cycle retry budget is exhausted and the caller must switch +// to the hourly ticker via enterHourlyMode + hourlyC. +func (s *iceRetryState) shouldRetry() bool { + if s.hourly != nil { + s.log.Debugf("hourly ICE retry attempt") + return true + } + + s.retries++ + if s.retries <= maxICERetries { + s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries) + return true + } + + return false +} + +// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false. +func (s *iceRetryState) enterHourlyMode() { + s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries) + s.hourly = time.NewTicker(iceRetryInterval) +} + +func (s *iceRetryState) hourlyC() <-chan time.Time { + if s.hourly == nil { + return nil + } + return s.hourly.C +} diff --git a/client/internal/peer/guard/ice_retry_state_test.go b/client/internal/peer/guard/ice_retry_state_test.go new file mode 100644 index 000000000..6a5b5a76f --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state_test.go @@ -0,0 +1,103 @@ +package guard + +import ( + "testing" + + log "github.com/sirupsen/logrus" +) + +func newTestRetryState() *iceRetryState { + return &iceRetryState{log: log.NewEntry(log.StandardLogger())} +} + +func TestICERetryState_AllowsInitialBudget(t *testing.T) { + s := newTestRetryState() + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries) + } + } +} + +func TestICERetryState_ExhaustsAfterBudget(t *testing.T) { + s := newTestRetryState() + + for i := 0; i < maxICERetries; i++ { + _ = s.shouldRetry() + } + + if s.shouldRetry() { + t.Fatalf("shouldRetry returned true after budget exhausted, want false") + } +} + +func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) { + s := newTestRetryState() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode") + } +} + +func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + + s.enterHourlyMode() + defer s.reset() + + if s.hourlyC() == nil { + t.Fatalf("hourlyC returned nil after enterHourlyMode") + } +} + +func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) { + s := newTestRetryState() + s.enterHourlyMode() + defer s.reset() + + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false in hourly mode, want true") + } + + // Subsequent calls also return true — we keep retrying on each hourly tick. + if !s.shouldRetry() { + t.Fatalf("second shouldRetry returned false in hourly mode, want true") + } +} + +func TestICERetryState_ResetRestoresBudget(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + s.enterHourlyMode() + + s.reset() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel after reset") + } + if s.retries != 0 { + t.Fatalf("retries = %d after reset, want 0", s.retries) + } + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i) + } + } +} + +func TestICERetryState_ResetIsIdempotent(t *testing.T) { + s := newTestRetryState() + s.reset() + s.reset() // second call must not panic or re-stop a nil ticker + + if s.hourlyC() != nil { + t.Fatalf("hourlyC non-nil after double reset") + } +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 6f4f5ad4f..0befd7438 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove return srw } -func (w *SRWatcher) Start() { +func (w *SRWatcher) Start(disableICEMonitor bool) { w.mu.Lock() defer w.mu.Unlock() @@ -50,8 +50,10 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) - go iceMonitor.Start(ctx, w.onICEChanged) + if !disableICEMonitor { + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) + go iceMonitor.Start(ctx, w.onICEChanged) + } w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 9b50cecd1..741dfce60 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" log "github.com/sirupsen/logrus" @@ -43,6 +44,10 @@ type OfferAnswer struct { SessionID *ICESessionID } +func (o *OfferAnswer) hasICECredentials() bool { + return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != "" +} + type Handshaker struct { mu sync.Mutex log *log.Entry @@ -59,6 +64,10 @@ type Handshaker struct { relayListener *AsyncOfferListener iceListener func(remoteOfferAnswer *OfferAnswer) + // remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers. + // When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses. + remoteICESupported atomic.Bool + // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection @@ -66,7 +75,7 @@ type Handshaker struct { } func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { - return &Handshaker{ + h := &Handshaker{ log: log, config: config, signaler: signaler, @@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W remoteOffersCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer), } + // assume remote supports ICE until we learn otherwise from received offers + h.remoteICESupported.Store(ice != nil) + return h +} + +func (h *Handshaker) RemoteICESupported() bool { + return h.remoteICESupported.Load() } func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { @@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } @@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } case remoteOfferAnswer := <-h.remoteAnswerCh: - h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): @@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error { } func (h *Handshaker) buildOfferAnswer() OfferAnswer { - uFrag, pwd := h.ice.GetLocalUserCredentials() - sid := h.ice.SessionID() answer := OfferAnswer{ - IceCredentials: IceCredentials{uFrag, pwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassAddr: h.config.RosenpassConfig.Addr, - SessionID: &sid, + } + + if h.ice != nil && h.RemoteICESupported() { + uFrag, pwd := h.ice.GetLocalUserCredentials() + sid := h.ice.SessionID() + answer.IceCredentials = IceCredentials{uFrag, pwd} + answer.SessionID = &sid } if addr, err := h.relay.RelayInstanceAddress(); err == nil { @@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer { return answer } + +func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) { + hasICE := offer.hasICECredentials() + prev := h.remoteICESupported.Swap(hasICE) + if prev != hasICE { + if hasICE { + h.log.Infof("remote peer started sending ICE credentials") + } else { + h.log.Infof("remote peer stopped sending ICE credentials") + if h.ice != nil { + h.ice.Close() + } + } + } +} diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index b28906625..f6eb87cca 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool { // SignalOfferAnswer signals either an offer or an answer to remote peer func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { - sessionIDBytes, err := offerAnswer.SessionID.Bytes() - if err != nil { - log.Warnf("failed to get session ID bytes: %v", err) + var sessionIDBytes []byte + if offerAnswer.SessionID != nil { + var err error + sessionIDBytes, err = offerAnswer.SessionID.Bytes() + if err != nil { + log.Warnf("failed to get session ID bytes: %v", err) + } } msg, err := signal.MarshalCredential( s.wgPrivateKey, diff --git a/combined/config.yaml.example b/combined/config.yaml.example index dce658d89..af85b0477 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -119,6 +119,8 @@ server: # Reverse proxy settings (optional) # reverseProxy: - # trustedHTTPProxies: [] - # trustedHTTPProxiesCount: 0 - # trustedPeers: [] + # trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"]) + # trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies) + # trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"]) + # accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely). + # accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value. diff --git a/flow/client/client_test.go b/flow/client/client_test.go index 55157acbc..c8f5f4af4 100644 --- a/flow/client/client_test.go +++ b/flow/client/client_test.go @@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) require.NoError(t, err) + + // Cleanups run LIFO: the goroutine-drain registered here runs after Close below, + // which is when Receive has actually returned. Without this, the Receive goroutine + // can outlive the test and call t.Logf after teardown, panicking. + receiveDone := make(chan struct{}) + t.Cleanup(func() { + select { + case <-receiveDone: + case <-time.After(2 * time.Second): + t.Error("Receive goroutine did not exit after Close") + } + }) t.Cleanup(func() { err := client.Close() assert.NoError(t, err, "failed to close flow") @@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { receivedAfterReconnect := make(chan struct{}) go func() { + defer close(receiveDone) err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { if msg.IsInitiator || len(msg.EventId) == 0 { return nil diff --git a/go.mod b/go.mod index 5172b1a78..1b5861a37 100644 --- a/go.mod +++ b/go.mod @@ -323,3 +323,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 + +replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index 9293ce73b..3772946e1 100644 --- a/go.sum +++ b/go.sum @@ -400,8 +400,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= @@ -449,6 +447,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= +github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= +github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 08da48264..2a3f840b4 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -472,7 +472,7 @@ start_services_and_show_instructions() { if [[ "$ENABLE_CROWDSEC" == "true" ]]; then echo "Registering CrowdSec bouncer..." local cs_retries=0 - while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do + while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do cs_retries=$((cs_retries + 1)) if [[ $cs_retries -ge 30 ]]; then echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 63be672e6..8106380f2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel/metric" "github.com/netbirdio/management-integrations/integrations" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" @@ -87,17 +88,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, authHeader) - if err != nil { + if err := m.checkJWTFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) case "token": - request, err := m.checkPATFromRequest(r, authHeader) - if err != nil { + if err := m.checkPATFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) // Check if it's a status error, otherwise default to Unauthorized if _, ok := status.FromError(err); !ok { @@ -106,7 +104,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { util.WriteError(r.Context(), err, w) return } - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) default: util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return @@ -115,19 +113,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } ctx := r.Context() userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) if err != nil { - return r, err + return err } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { @@ -143,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := m.ensureAccount(ctx, userAuth) if err != nil { - return r, err + return err } if userAuth.AccountId != accountId { @@ -153,7 +151,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) if err != nil { - return r, err + return err } err = m.syncUserJWTGroups(ctx, userAuth) @@ -164,17 +162,19 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] _, err = m.getUserFromUserAuth(ctx, userAuth) if err != nil { log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) - return r, err + return err } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } if m.patUsageTracker != nil { @@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] if m.rateLimiter != nil && !isTerraformRequest(r) { if !m.rateLimiter.Allow(token) { - return r, status.Errorf(status.TooManyRequests, "too many requests") + return status.Errorf(status.TooManyRequests, "too many requests") } } ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { - return r, fmt.Errorf("invalid Token: %w", err) + return fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.GetExpirationDate()) { - return r, fmt.Errorf("token expired") + return fmt.Errorf("token expired") } err = m.authManager.MarkPATUsed(ctx, pat.ID) if err != nil { - return r, err + return err } userAuth := auth.UserAuth{ @@ -216,7 +216,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] } } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } func isTerraformRequest(r *http.Request) bool { diff --git a/management/server/policy.go b/management/server/policy.go index 3e84c3d10..48297ca11 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,6 +5,7 @@ import ( _ "embed" "github.com/rs/xid" + "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var isUpdate = policy.ID != "" var updateAccountPeers bool var action = activity.PolicyAdded + var unchanged bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { - return err - } - - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy) if err != nil { return err } - saveFunc := transaction.CreatePolicy if isUpdate { - action = activity.PolicyUpdated - saveFunc = transaction.SavePolicy - } + if policy.Equal(existingPolicy) { + logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID) + unchanged = true + return nil + } - if err = saveFunc(ctx, policy); err != nil { - return err + action = activity.PolicyUpdated + + updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy) + if err != nil { + return err + } + + if err = transaction.SavePolicy(ctx, policy); err != nil { + return err + } + } else { + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) + if err != nil { + return err + } + + if err = transaction.CreatePolicy(ctx, policy); err != nil { + return err + } } return transaction.IncrementNetworkSerial(ctx, accountID) @@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return nil, err } + if unchanged { + return policy, nil + } + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { @@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) if err != nil { return err } @@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } -// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { - if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) - if err != nil { - return false, err - } - - if !policy.Enabled && !existingPolicy.Enabled { - return false, nil - } - - for _, rule := range existingPolicy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil - } - } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) - if err != nil { - return false, err - } - - if hasPeers { +// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil } } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) +} + +func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) { + if !policy.Enabled && !existingPolicy.Enabled { + return false, nil + } + + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + for _, rule := range policy.Rules { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil @@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } -// validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { +// validatePolicy validates the policy and its rules. For updates it returns +// the existing policy loaded from the store so callers can avoid a second read. +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) { + var existingPolicy *types.Policy if policy.ID != "" { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + var err error + existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { - return err + return nil, err } // TODO: Refactor to support multiple rules per policy @@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri for _, rule := range policy.Rules { if rule.ID != "" && !existingRuleIDs[rule.ID] { - return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) } } } else { @@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) if err != nil { - return err + return nil, err } postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) if err != nil { - return err + return nil, err } for i, rule := range policy.Rules { @@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - return nil + return existingPolicy, nil } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index 28e8457e2..e48e6d64a 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { } }) - h.ServeHTTP(w, r.WithContext(ctx)) + // Hold on to req so auth's in-place ctx update is visible after ServeHTTP. + req := r.WithContext(ctx) + h.ServeHTTP(w, req) close(handlerDone) - userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) - if err == nil { - if userAuth.AccountId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId) - } - if userAuth.UserId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId) - } - } + ctx = req.Context() if w.Status() > 399 { log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) diff --git a/management/server/types/policy.go b/management/server/types/policy.go index d4e1a8816..d410aec8d 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy { return c } +func (p *Policy) Equal(other *Policy) bool { + if p == nil || other == nil { + return p == other + } + + if p.ID != other.ID || + p.AccountID != other.AccountID || + p.Name != other.Name || + p.Description != other.Description || + p.Enabled != other.Enabled { + return false + } + + if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) { + return false + } + + if len(p.Rules) != len(other.Rules) { + return false + } + + otherRules := make(map[string]*PolicyRule, len(other.Rules)) + for _, r := range other.Rules { + otherRules[r.ID] = r + } + for _, r := range p.Rules { + otherRule, ok := otherRules[r.ID] + if !ok { + return false + } + if !r.Equal(otherRule) { + return false + } + } + + return true +} + // EventMeta returns activity event meta related to this policy func (p *Policy) EventMeta() map[string]any { return map[string]any{"name": p.Name} diff --git a/management/server/types/policy_test.go b/management/server/types/policy_test.go new file mode 100644 index 000000000..b1d7aabc2 --- /dev/null +++ b/management/server/types/policy_test.go @@ -0,0 +1,193 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRules(t *testing.T) { + a := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + b := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRuleCount(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc3", "pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2", "pc3"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentPostureChecks(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentScalarFields(t *testing.T) { + base := Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Description: "desc", + Enabled: true, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Description = "changed" + assert.False(t, base.Equal(&other)) +} + +func TestPolicyEqual_NilCases(t *testing.T) { + var a *Policy + var b *Policy + assert.True(t, a.Equal(b)) + + a = &Policy{ID: "pol1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyEqual_RulesMismatchByID(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_FullScenario(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc2", "pc1"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g2", "g1"}, + Destinations: []string{"g4", "g3"}, + Ports: []string{"443", "80", "8080"}, + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + }, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc1", "pc2"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g1", "g2"}, + Destinations: []string{"g3", "g4"}, + Ports: []string{"80", "8080", "443"}, + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + }, + }, + } + assert.True(t, a.Equal(b)) +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index bb75dd555..52c494a6a 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -1,6 +1,8 @@ package types import ( + "slices" + "github.com/netbirdio/netbird/shared/management/proto" ) @@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule { } return rule } + +func (pm *PolicyRule) Equal(other *PolicyRule) bool { + if pm == nil || other == nil { + return pm == other + } + + if pm.ID != other.ID || + pm.PolicyID != other.PolicyID || + pm.Name != other.Name || + pm.Description != other.Description || + pm.Enabled != other.Enabled || + pm.Action != other.Action || + pm.Bidirectional != other.Bidirectional || + pm.Protocol != other.Protocol || + pm.SourceResource != other.SourceResource || + pm.DestinationResource != other.DestinationResource || + pm.AuthorizedUser != other.AuthorizedUser { + return false + } + + if !stringSlicesEqualUnordered(pm.Sources, other.Sources) { + return false + } + if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) { + return false + } + if !stringSlicesEqualUnordered(pm.Ports, other.Ports) { + return false + } + if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) { + return false + } + if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) { + return false + } + + return true +} + +func stringSlicesEqualUnordered(a, b []string) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + sorted1 := make([]string, len(a)) + sorted2 := make([]string, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.Sort(sorted1) + slices.Sort(sorted2) + return slices.Equal(sorted1, sorted2) +} + +func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + cmp := func(x, y RulePortRange) int { + if x.Start != y.Start { + if x.Start < y.Start { + return -1 + } + return 1 + } + if x.End != y.End { + if x.End < y.End { + return -1 + } + return 1 + } + return 0 + } + sorted1 := make([]RulePortRange, len(a)) + sorted2 := make([]RulePortRange, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.SortFunc(sorted1, cmp) + slices.SortFunc(sorted2, cmp) + return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool { + return x.Start == y.Start && x.End == y.End + }) +} + +func authorizedGroupsEqual(a, b map[string][]string) bool { + if len(a) != len(b) { + return false + } + for k, va := range a { + vb, ok := b[k] + if !ok { + return false + } + if !stringSlicesEqualUnordered(va, vb) { + return false + } + } + return true +} diff --git a/management/server/types/policyrule_test.go b/management/server/types/policyrule_test.go new file mode 100644 index 000000000..816e72abb --- /dev/null +++ b/management/server/types/policyrule_test.go @@ -0,0 +1,194 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80", "22"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"22", "443", "80"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPorts(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "22"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2", "g3"}, + Destinations: []string{"g4", "g5"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g3", "g1", "g2"}, + Destinations: []string{"g5", "g4"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentSources(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 443}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1", "u2", "u3"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u3", "u1", "u2"}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g2": {"u1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) { + base := PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Name: "test", + Description: "desc", + Enabled: true, + Action: PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Action = PolicyTrafficActionDrop + assert.False(t, base.Equal(&other)) + + other = base + other.Protocol = PolicyRuleProtocolUDP + assert.False(t, base.Equal(&other)) +} + +func TestPolicyRuleEqual_NilCases(t *testing.T) { + var a *PolicyRule + var b *PolicyRule + assert.True(t, a.Equal(b)) + + a = &PolicyRule{ID: "rule1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyRuleEqual_EmptySlices(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{}, + Sources: nil, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: nil, + Sources: []string{}, + } + assert.True(t, a.Equal(b)) +} + diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index dcabd2f0f..96ac6670b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -30,6 +30,8 @@ import ( const ConnectTimeout = 10 * time.Second +const healthCheckTimeout = 5 * time.Second + const ( // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) // for the management client connection. Value is in bytes. @@ -532,7 +534,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() _, err := c.realClient.GetServerKey(ctx, &proto.Empty{}) diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5368b57a2..d0f598dd7 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -23,6 +23,8 @@ import ( "github.com/netbirdio/netbird/util/wsproxy" ) +const healthCheckTimeout = 5 * time.Second + // ConnStateNotifier is a wrapper interface of the status recorder type ConnStateNotifier interface { MarkSignalDisconnected(error) @@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() _, err := c.realClient.Send(ctx, &proto.EncryptedMessage{ Key: c.key.PublicKey().String(),