Merge branch 'main' of github.com:netbirdio/netbird into feat/pat-creation-on-setup

This commit is contained in:
jnfrati
2026-04-27 16:59:12 +02:00
115 changed files with 5448 additions and 6084 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.1.2" SIGN_PIPE_VER: "v0.1.4"
GORELEASER_VER: "v2.14.3" GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT): $(GOLANGCI_LINT):
@echo "Installing golangci-lint..." @echo "Installing golangci-lint..."
@mkdir -p ./bin @mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push) # Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT) lint: $(GOLANGCI_LINT)

View File

@@ -8,6 +8,7 @@ import (
"os" "os"
"slices" "slices"
"sync" "sync"
"time"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -15,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -26,6 +28,7 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -68,7 +71,30 @@ type Client struct {
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -173,11 +203,12 @@ func (c *Client) Stop() {
} }
func (c *Client) RenewTun(fd int) error { func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil { cc := c.getConnectClient()
if cc == nil {
return fmt.Errorf("engine not running") return fmt.Errorf("engine not running")
} }
e := c.connectClient.Engine() e := cc.Engine()
if e == nil { if e == nil {
return fmt.Errorf("engine not initialized") return fmt.Errorf("engine not initialized")
} }
@@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd) return e.RenewTun(fd)
} }
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level // SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() { func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
@@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray {
} }
func (c *Client) Networks() *NetworkArray { func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil { cc := c.getConnectClient()
if cc == nil {
log.Error("not connected") log.Error("not connected")
return nil return nil
} }
engine := c.connectClient.Engine() engine := cc.Engine()
if engine == nil { if engine == nil {
log.Error("could not get engine") log.Error("could not get engine")
return nil return nil
@@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
} }
func (c *Client) getRouteManager() (routemanager.Manager, error) { func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient client := c.getConnectClient()
if client == nil { if client == nil {
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }

View File

@@ -7,4 +7,5 @@ package android
type PlatformFiles interface { type PlatformFiles interface {
ConfigurationFilePath() string ConfigurationFilePath() string
StateFilePath() string StateFilePath() string
CacheDir() string
} }

View File

@@ -0,0 +1,11 @@
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
// versions, which returns EPERM to any other process that tries to insert
// rules into them. The workaround mirrors what Tailscale does: let firewalld
// itself add the accept rules to its own chains by trusting the interface.
package firewalld
// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"

View File

@@ -0,0 +1,260 @@
//go:build linux
package firewalld
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
)
const (
dbusDest = "org.fedoraproject.FirewallD1"
dbusPath = "/org/fedoraproject/FirewallD1"
dbusRootIface = "org.fedoraproject.FirewallD1"
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
errZoneAlreadySet = "ZONE_ALREADY_SET"
errAlreadyEnabled = "ALREADY_ENABLED"
errUnknownIface = "UNKNOWN_INTERFACE"
errNotEnabled = "NOT_ENABLED"
// callTimeout bounds each individual DBus or firewall-cmd invocation.
// A fresh context is created for each call so a slow DBus probe can't
// exhaust the deadline before the firewall-cmd fallback gets to run.
callTimeout = 3 * time.Second
)
var (
errDBusUnavailable = errors.New("firewalld dbus unavailable")
// trustLogOnce ensures the "added to trusted zone" message is logged at
// Info level only for the first successful add per process; repeat adds
// from other init paths are quieter.
trustLogOnce sync.Once
parentCtxMu sync.RWMutex
parentCtx context.Context = context.Background()
)
// SetParentContext installs a parent context whose cancellation aborts any
// in-flight TrustInterface call. It does not affect UntrustInterface, which
// always uses a fresh Background-rooted timeout so cleanup can still run
// during engine shutdown when the engine context is already cancelled.
func SetParentContext(ctx context.Context) {
parentCtxMu.Lock()
parentCtx = ctx
parentCtxMu.Unlock()
}
func getParentContext() context.Context {
parentCtxMu.RLock()
defer parentCtxMu.RUnlock()
return parentCtx
}
// TrustInterface places iface into firewalld's trusted zone if firewalld is
// running. It is idempotent and best-effort: errors are returned so callers
// can log, but a non-running firewalld is not an error. Only the first
// successful call per process logs at Info. Respects the parent context set
// via SetParentContext so startup-time cancellation unblocks it.
func TrustInterface(iface string) error {
parent := getParentContext()
if !isRunning(parent) {
return nil
}
if err := addTrusted(parent, iface); err != nil {
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
}
trustLogOnce.Do(func() {
log.Infof("added %s to firewalld trusted zone", iface)
})
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
return nil
}
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
// during shutdown after the engine context has been cancelled.
func UntrustInterface(iface string) error {
if !isRunning(context.Background()) {
return nil
}
if err := removeTrusted(context.Background(), iface); err != nil {
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
}
return nil
}
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, callTimeout)
}
func isRunning(parent context.Context) bool {
ctx, cancel := newCallContext(parent)
ok, err := isRunningDBus(ctx)
cancel()
if err == nil {
return ok
}
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
ctx, cancel = newCallContext(parent)
defer cancel()
return isRunningCLI(ctx)
}
return false
}
func addTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := addDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return addCLI(ctx, iface)
}
func removeTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := removeDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return removeCLI(ctx, iface)
}
func isRunningDBus(ctx context.Context) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
var zone string
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
}
return true, nil
}
func isRunningCLI(ctx context.Context) bool {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return false
}
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
}
func addDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errAlreadyEnabled) {
return nil
}
if dbusErrContains(call.Err, errZoneAlreadySet) {
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
if move.Err != nil {
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
}
return nil
}
return fmt.Errorf("firewalld addInterface: %w", call.Err)
}
func removeDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
return nil
}
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
}
func addCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
// --change-interface (no --permanent) binds the interface for the
// current runtime only; we do not want membership to persist across
// reboots because netbird re-asserts it on every startup.
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
).CombinedOutput()
if err != nil {
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
func removeCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
).CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
return nil
}
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
}
return nil
}
func dbusErrContains(err error, code string) bool {
if err == nil {
return false
}
var de dbus.Error
if errors.As(err, &de) {
for _, b := range de.Body {
if s, ok := b.(string); ok && strings.Contains(s, code) {
return true
}
}
}
return strings.Contains(err.Error(), code)
}

View File

@@ -0,0 +1,49 @@
//go:build linux
package firewalld
import (
"errors"
"testing"
"github.com/godbus/dbus/v5"
)
func TestDBusErrContains(t *testing.T) {
tests := []struct {
name string
err error
code string
want bool
}{
{"nil error", nil, errZoneAlreadySet, false},
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
{
"dbus.Error body match",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
errZoneAlreadySet,
true,
},
{
"dbus.Error body miss",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
errAlreadyEnabled,
false,
},
{
"dbus.Error non-string body falls back to Error()",
dbus.Error{Name: "x", Body: []any{123}},
"x",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := dbusErrContains(tc.err, tc.code)
if got != tc.want {
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
}
})
}
}

View File

@@ -0,0 +1,25 @@
//go:build !linux
package firewalld
import "context"
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
// intentionally empty: firewalld is a Linux-only daemon
}
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func TrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func UntrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err) log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
} }
// Trust after all fatal init steps so a later failure doesn't leave the
// interface in firewalld's trusted zone without a corresponding Close.
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
go func() { go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
} }
// Appending to merr intentionally blocks DeleteState below so ShutdownState
// stays persisted and the crash-recovery path retries firewalld cleanup.
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
// attempt to delete state only if all other operations succeeded // attempt to delete state only if all other operations succeeded
if merr == nil { if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("allow netbird interface traffic: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -14,6 +14,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
return fmt.Errorf("flush allow input netbird rules: %w", err) return fmt.Errorf("flush allow input netbird rules: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -19,6 +19,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
@@ -40,6 +41,8 @@ const (
chainNameForward = "FORWARD" chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward" chainNameMangleForward = "netbird-mangle-forward"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept" userDataAcceptInputRule = "inputaccept"
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
} }
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil { if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
} }
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err) log.Errorf("failed to refresh rules: %s", err)
} }
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false return false
} }
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip all iptables-managed tables in the ip family // Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false return false

View File

@@ -3,6 +3,9 @@
package uspfilter package uspfilter
import ( import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager) return m.nativeFirewall.Close(stateManager)
} }
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil return nil
} }
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird() return m.nativeFirewall.AllowNetbird()
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -9,6 +9,7 @@ import (
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() wgaddr.Address Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device

View File

@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
NameFunc func() string
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
} }
func (i *IFaceMock) Name() string {
if i.NameFunc == nil {
return "wgtest"
}
return i.NameFunc()
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device { func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil { if i.GetWGDeviceFunc == nil {
return nil return nil

View File

@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++ ipv6Count++
} }
assert.Equal(t, packetsPerFamily, ipv4Count) // Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
assert.Equal(t, packetsPerFamily, ipv6Count) // routing-correctness checks above are the real assertions; the counts
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
} }
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {

View File

@@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface // Close closes the tunnel interface
func (w *WGIface) Close() error { func (w *WGIface) Close() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error var result *multierror.Error
@@ -225,7 +224,15 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
} }
if err := w.tun.Close(); err != nil { // Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
} }

View File

@@ -0,0 +1,113 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
} }
if u.address.Network.Contains(a) { if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted) u.addrCache.Store(addr.String(), isRouted)
if isRouted { if isRouted {
// Extra log, as the error only shows up with ICE logging enabled // Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix) log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
} }
} }

View File

@@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort, dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener, dnsReadyListener dns.ReadyListener,
stateFilePath string, stateFilePath string,
cacheDir string,
) error { ) error {
// in case of non Android os these variables will be nil // in case of non Android os these variables will be nil
mobileDependency := MobileDependency{ mobileDependency := MobileDependency{
@@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath, StateFilePath: stateFilePath,
TempDir: cacheDir,
} }
return c.run(mobileDependency, nil, "") return c.run(mobileDependency, nil, "")
} }
@@ -338,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager) c.statusRecorder.SetRelayMgr(relayManager)

View File

@@ -16,7 +16,6 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"slices"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -31,7 +30,6 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
) )
const readmeContent = `Netbird debug bundle const readmeContent = `Netbird debug bundle
@@ -234,6 +232,7 @@ type BundleGenerator struct {
statusRecorder *peer.Status statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse syncResponse *mgmProto.SyncResponse
logPath string logPath string
tempDir string
cpuProfile []byte cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter clientMetrics MetricsExporter
@@ -256,6 +255,7 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse SyncResponse *mgmProto.SyncResponse
LogPath string LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter ClientMetrics MetricsExporter
@@ -275,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder, statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse, syncResponse: deps.SyncResponse,
logPath: deps.LogPath, logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile, cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus, refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics, clientMetrics: deps.ClientMetrics,
@@ -287,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location. // Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) { func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
if err != nil { if err != nil {
return "", fmt.Errorf("create zip file: %w", err) return "", fmt.Errorf("create zip file: %w", err)
} }
@@ -373,15 +374,8 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err) log.Errorf("failed to add wg show output: %v", err)
} }
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { if err := g.addPlatformLog(); err != nil {
if err := g.addLogfile(); err != nil { log.Errorf("failed to add logs to debug bundle: %v", err)
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
} }
if err := g.addUpdateLogs(); err != nil { if err := g.addUpdateLogs(); err != nil {

View File

@@ -0,0 +1,41 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -0,0 +1,25 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -3,10 +3,12 @@ package debug
import ( import (
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci") t.Skip("Skipping upload test on docker ci")
} }
testDir := t.TempDir() testDir := t.TempDir()
testURL := "http://localhost:8080" addr := reserveLoopbackPort(t)
testURL := "http://" + addr
t.Setenv("SERVER_URL", testURL) t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir) t.Setenv("STORE_DIR", testDir)
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err) t.Errorf("Failed to stop server: %v", err)
} }
}) })
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile") file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content") fileContent := []byte("test file content")
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent) require.Equal(t, fileContent, createdFileContent)
} }
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}

View File

@@ -13,6 +13,7 @@ import (
const ( const (
defaultResolvConfPath = "/etc/resolv.conf" defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
) )
type resolvConf struct { type resolvConf struct {

View File

@@ -1,7 +1,10 @@
package dns package dns
import ( import (
"context"
"fmt" "fmt"
"math"
"net"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
} }
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) { if !c.isHandlerMatch(qname, entry) {
continue continue
} }
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime)) cw.response.Len(), meta, time.Since(startTime))
} }
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch { switch {
case entry.Pattern == ".": case entry.Pattern == ".":
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
} }
} }
} }
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,11 +1,15 @@
package dns_test package dns_test
import ( import (
"context"
"net"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
}) })
} }
} }
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, reason, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err) return nil, fmt.Errorf("get os dns manager type: %w", err)
} }
log.Infof("System DNS manager discovered: %s", osManager) log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
mgr, err := newHostManagerFromType(wgInterface, osManager) mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil { if err != nil {
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
} }
} }
func getOSDNSManagerType() (osManagerType, error) { func getOSDNSManagerType() (osManagerType, string, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
file, err := os.Open(defaultResolvConfPath) file, err := os.Open(defaultResolvConfPath)
if err != nil { if err != nil {
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
} }
defer func() { defer func() {
if err := file.Close(); err != nil { if cerr := file.Close(); cerr != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err) log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
} }
}() }()
var rejected []string
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
continue continue
} }
if text[0] != '#' { if text[0] != '#' {
return fileManager, nil break
} }
if strings.Contains(text, fileGeneratedResolvConfContentHeader) { if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
return netbirdManager, nil return mgr, reason, nil, nil
} } else if rej != "" {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { rejected = append(rejected, rej)
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
} }
} }
if err := scanner.Err(); err != nil && err != io.EOF { if err := scanner.Err(); err != nil && err != io.EOF {
return 0, fmt.Errorf("scan: %w", err) return 0, "", nil, fmt.Errorf("scan: %w", err)
} }
return 0, "", rejected, nil
return fileManager, nil
} }
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager. // matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
func checkStub() bool { func checkStub() bool {
rConf, err := parseDefaultResolvConf() rConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Warnf("failed to parse resolv conf: %s", err) log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
return true return true
} }
@@ -139,3 +178,36 @@ func checkStub() bool {
return false return false
} }
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -0,0 +1,76 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,40 +2,83 @@ package mgmt
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"net/url" "net/url"
"os"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
const dnsTimeout = 5 * time.Second const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
// Resolver caches critical NetBird infrastructure domains // envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct { type Resolver struct {
records map[dns.Question][]dns.RR records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex mutex sync.RWMutex
}
type ipsResponse struct { chain ChainResolver
ips []netip.Addr chainMaxPriority int
err error refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
} }
// NewResolver creates a new management domains cache resolver. // NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
records: make(map[dns.Question][]dns.RR), records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
} }
} }
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver" return "MgmtCacheResolver"
} }
// ServeDNS implements dns.Handler interface. // SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
m.continueToNext(w, r) m.continueToNext(w, r)
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
m.mutex.RLock() m.mutex.RLock()
records, found := m.records[question] cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
m.mutex.RUnlock() m.mutex.RUnlock()
if !found { if !found {
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{} resp := &dns.Msg{}
resp.SetReply(r) resp.SetReply(r)
resp.Authoritative = false resp.Authoritative = false
resp.RecursionAvailable = true resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// AddDomain manually adds a domain to cache by resolving it. // AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout) ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel() defer cancel()
ips, err := lookupIPWithExtraTimeout(ctx, d) aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if err != nil {
return err if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
} }
var aRecords, aaaaRecords []dns.RR if len(aRecords) == 0 && len(aaaaRecords) == 0 {
for _, ip := range ips { if err := errors.Join(errA, errAAAA); err != nil {
if ip.Is4() { return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
} }
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
} }
now := time.Now()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
if len(aRecords) > 0 { m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
aQuestion := dns.Question{ m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
if len(aaaaRecords) > 0 { log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords)) d.SafeString(), len(aRecords), len(aaaaRecords))
return nil return nil
} }
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { // applyFamilyRecords writes records, evicts on NODATA, leaves the cache
log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) // untouched on error. Caller holds m.mutex.
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
resultChan := make(chan *ipsResponse, 1) q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
go func() { // scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) // unique in-flight key; bursty stale hits share its channel. expected is the
resultChan <- &ipsResponse{ // cachedRecord pointer observed by the caller; the refresh only mutates the
err: err, // cache if that pointer is still the one stored, so a stale in-flight refresh
ips: ips, // can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
} }
}() logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
var resp *ipsResponse return err
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
} }
if resp.err != nil { // NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
} }
return resp.ips, nil
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
} }
// PopulateFromConfig extracts and caches domains from the client configuration. // PopulateFromConfig extracts and caches domains from the client configuration.
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
aQuestion := dns.Question{ qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
Name: dnsName, qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
Qtype: dns.TypeA, delete(m.records, qA)
Qclass: dns.ClassINET, delete(m.records, qAAAA)
} delete(m.refreshing, qA)
delete(m.records, aQuestion) delete(m.refreshing, qAAAA)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString()) log.Debugf("removed domain=%s from cache", d.SafeString())
return nil return nil
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains return domains
} }
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -0,0 +1,408 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains()) assert.False(t, resolver.MatchSubdomains())
} }
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) { func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -212,6 +212,7 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver() mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,

View File

@@ -26,6 +26,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
@@ -140,6 +141,7 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config ProfileConfig *profilemanager.Config
LogPath string LogPath string
TempDir string
} }
// EngineServices holds the external service dependencies required by the Engine. // EngineServices holds the external service dependencies required by the Engine.
@@ -569,7 +571,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx) e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start() e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
@@ -603,6 +605,8 @@ func (e *Engine) createFirewall() error {
return nil return nil
} }
firewalld.SetParentContext(e.ctx)
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil { if err != nil {
@@ -1095,6 +1099,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder, StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse, SyncResponse: syncResponse,
LogPath: e.config.LogPath, LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics, ClientMetrics: e.clientMetrics,
RefreshStatus: func() { RefreshStatus: func() {
e.RunHealthProbes(true) e.RunHealthProbes(true)

View File

@@ -22,4 +22,8 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager DnsManager dns.IosDnsManager
FileDescriptor int32 FileDescriptor int32
StateFilePath string StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
} }

View File

@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() forceRelay := IsForceRelayed()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) if !forceRelay {
if err != nil { relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
return err workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
} }
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() { if !forceRelay {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
} }
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgWatcherCancel() conn.wgWatcherCancel()
} }
conn.workerRelay.CloseConn() conn.workerRelay.CloseConn()
conn.workerICE.Close() if conn.workerICE != nil {
conn.workerICE.Close()
}
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn() err := conn.wgProxyRelay.CloseConn()
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate() conn.dumpState.RemoteCandidate()
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) if conn.workerICE != nil {
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
} }
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusConnecting return StatusConnecting
} }
func (conn *Conn) isConnectedOnAllWay() (connected bool) { // isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
// would be better to protect this with a mutex, but it could cause deadlock with Close function //
// The result is a tri-state:
// - ConnStatusConnected: all available transports are up
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
// - ConnStatusDisconnected: no working transport
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
defer func() { defer func() {
if !connected { if status == guard.ConnStatusDisconnected {
conn.logTraceConnState() conn.logTraceConnState()
} }
}() }()
// For JS platform: only relay connection is supported iceWorkerCreated := conn.workerICE != nil
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected var iceInProgress bool
if iceWorkerCreated {
iceInProgress = conn.workerICE.InProgress()
} }
// For non-JS platforms: check ICE connection status return evalConnStatus(connStatusInputs{
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { forceRelay: IsForceRelayed(),
return false peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
} relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
// If relay is supported with peer, it must also be connected iceWorkerCreated: iceWorkerCreated,
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
if conn.statusRelay.Get() == worker.StatusDisconnected { iceInProgress: iceInProgress,
return false })
}
}
return true
} }
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil return remoteRosenpassPubKey != nil
} }
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
// "Relay up and needed" — the peer uses relay and the transport is connected.
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
if in.forceRelay {
return boolToConnStatus(relayUsedAndUp)
}
// Remote peer doesn't support ICE, or we haven't created the worker yet:
// relay is the only possible transport.
if !in.remoteSupportsICE || !in.iceWorkerCreated {
return boolToConnStatus(relayUsedAndUp)
}
// ICE counts as "up" when the status is anything other than Disconnected, OR
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
iceUp := in.iceStatusConnecting || in.iceInProgress
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
relayOK := !in.peerUsesRelay || in.relayConnected
switch {
case iceUp && relayOK:
return guard.ConnStatusConnected
case relayUsedAndUp:
// Relay is up but ICE is down — partially connected.
return guard.ConnStatusPartiallyConnected
default:
return guard.ConnStatusDisconnected
}
}
func boolToConnStatus(connected bool) guard.ConnStatus {
if connected {
return guard.ConnStatusConnected
}
return guard.ConnStatusDisconnected
}

View File

@@ -13,6 +13,20 @@ const (
StatusConnected StatusConnected
) )
// connStatusInputs is the primitive-valued snapshot of the state that drives the
// tri-state connection classification. Extracted so the decision logic can be unit-tested
// without constructing full Worker/Handshaker objects.
type connStatusInputs struct {
forceRelay bool // NB_FORCE_RELAY or JS/WASM
peerUsesRelay bool // remote peer advertises relay support AND local has relay
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
remoteSupportsICE bool // remote peer sent ICE credentials
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
iceStatusConnecting bool // statusICE is anything other than Disconnected
iceInProgress bool // a negotiation is currently in flight
}
// ConnStatus describe the status of a peer's connection // ConnStatus describe the status of a peer's connection
type ConnStatus int32 type ConnStatus int32

View File

@@ -0,0 +1,201 @@
package peer
import (
"testing"
"github.com/netbirdio/netbird/client/internal/peer/guard"
)
func TestEvalConnStatus_ForceRelay(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "force relay, peer uses relay, relay up",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: true,
},
want: guard.ConnStatusConnected,
},
{
name: "force relay, peer uses relay, relay down",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: false,
},
want: guard.ConnStatusDisconnected,
},
{
name: "force relay, peer does NOT use relay - disconnected forever",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: false,
relayConnected: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "remote does not support ICE, peer uses relay, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer uses relay, relay down",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE worker not yet created, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: true,
iceWorkerCreated: false,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer does not use relay",
in: connStatusInputs{
peerUsesRelay: false,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
base := connStatusInputs{
remoteSupportsICE: true,
iceWorkerCreated: true,
}
tests := []struct {
name string
mutator func(*connStatusInputs)
want guard.ConnStatus
}{
{
name: "ICE connected, relay connected, peer uses relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE connected, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE InProgress only, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.iceStatusConnecting = false
in.iceInProgress = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE down, relay up, peer uses relay -> partial",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusPartiallyConnected,
},
{
name: "ICE down, peer does NOT use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = false
in.iceStatusConnecting = true
},
// relayOK = false (peer uses relay but it's down), iceUp = true
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
// falls into default: Disconnected.
want: guard.ConnStatusDisconnected,
},
{
name: "ICE down, relay up but peer does not use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = true // not actually used since peer doesn't rely on it
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
in := base
tc.mutator(&in)
if got := evalConnStatus(in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
}
})
}
}

View File

@@ -10,7 +10,7 @@ const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY" EnvKeyNBForceRelay = "NB_FORCE_RELAY"
) )
func isForceRelayed() bool { func IsForceRelayed() bool {
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
return true return true
} }

View File

@@ -8,7 +8,19 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type isConnectedFunc func() bool // ConnStatus represents the connection state as seen by the guard.
type ConnStatus int
const (
// ConnStatusDisconnected means neither ICE nor Relay is connected.
ConnStatusDisconnected ConnStatus = iota
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
ConnStatusPartiallyConnected
// ConnStatusConnected means all required connections are established.
ConnStatusConnected
)
type connStatusFunc func() ConnStatus
// Guard is responsible for the reconnection logic. // Guard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues. // It will trigger to send an offer to the peer then has connection issues.
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
// - ICE candidate changes // - ICE candidate changes
type Guard struct { type Guard struct {
log *log.Entry log *log.Entry
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay connStatusFunc
timeout time.Duration timeout time.Duration
srWatcher *SRWatcher srWatcher *SRWatcher
relayedConnDisconnected chan struct{} relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{} iCEConnDisconnected chan struct{}
} }
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
log: log, log: log,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
} }
} }
// reconnectLoopWithRetry periodically check the connection status. // reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported //
// Behavior depends on the connection state reported by isConnectedOnAllWay:
// - Connected: no action, the peer is fully reachable.
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
//
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan) defer g.srWatcher.RemoveListener(srReconnectedChan)
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
tickerChannel := ticker.C tickerChannel := ticker.C
iceState := &iceRetryState{log: g.log}
defer iceState.reset()
for { for {
select { select {
case t := <-tickerChannel: case <-tickerChannel:
if t.IsZero() { switch g.isConnectedOnAllWay() {
g.log.Infof("retry timed out, stop periodic offer sending") case ConnStatusConnected:
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop // all good, nothing to do
tickerChannel = make(<-chan time.Time) case ConnStatusDisconnected:
continue callback()
case ConnStatusPartiallyConnected:
if iceState.shouldRetry() {
callback()
} else {
iceState.enterHourlyMode()
ticker.Stop()
tickerChannel = iceState.hourlyC()
}
} }
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker") g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-g.iCEConnDisconnected: case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker") g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-srReconnectedChan: case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker") g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-ctx.Done(): case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop") g.log.Debugf("context is done, stop reconnect loop")
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
return backoff.NewTicker(bo) return backoff.NewTicker(bo)
} }
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1, RandomizationFactor: 0.1,

View File

@@ -0,0 +1,61 @@
package guard
import (
"time"
log "github.com/sirupsen/logrus"
)
const (
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
maxICERetries = 3
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
iceRetryInterval = 1 * time.Hour
)
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
// After maxICERetries attempts it switches to a periodic hourly retry.
type iceRetryState struct {
log *log.Entry
retries int
hourly *time.Ticker
}
func (s *iceRetryState) reset() {
s.retries = 0
if s.hourly != nil {
s.hourly.Stop()
s.hourly = nil
}
}
// shouldRetry reports whether the caller should send another ICE offer on this tick.
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
// to the hourly ticker via enterHourlyMode + hourlyC.
func (s *iceRetryState) shouldRetry() bool {
if s.hourly != nil {
s.log.Debugf("hourly ICE retry attempt")
return true
}
s.retries++
if s.retries <= maxICERetries {
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
return true
}
return false
}
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
func (s *iceRetryState) enterHourlyMode() {
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
s.hourly = time.NewTicker(iceRetryInterval)
}
func (s *iceRetryState) hourlyC() <-chan time.Time {
if s.hourly == nil {
return nil
}
return s.hourly.C
}

View File

@@ -0,0 +1,103 @@
package guard
import (
"testing"
log "github.com/sirupsen/logrus"
)
func newTestRetryState() *iceRetryState {
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
}
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
s := newTestRetryState()
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
}
}
}
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries; i++ {
_ = s.shouldRetry()
}
if s.shouldRetry() {
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
}
}
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
s := newTestRetryState()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
}
}
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
defer s.reset()
if s.hourlyC() == nil {
t.Fatalf("hourlyC returned nil after enterHourlyMode")
}
}
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
s := newTestRetryState()
s.enterHourlyMode()
defer s.reset()
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false in hourly mode, want true")
}
// Subsequent calls also return true — we keep retrying on each hourly tick.
if !s.shouldRetry() {
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
}
}
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
s.reset()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel after reset")
}
if s.retries != 0 {
t.Fatalf("retries = %d after reset, want 0", s.retries)
}
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
}
}
}
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
s := newTestRetryState()
s.reset()
s.reset() // second call must not panic or re-stop a nil ticker
if s.hourlyC() != nil {
t.Fatalf("hourlyC non-nil after double reset")
}
}

View File

@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
return srw return srw
} }
func (w *SRWatcher) Start() { func (w *SRWatcher) Start(disableICEMonitor bool) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) if !disableICEMonitor {
go iceMonitor.Start(ctx, w.onICEChanged) iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
}
w.signalClient.SetOnReconnectedListener(w.onReconnected) w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"sync" "sync"
"sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -43,6 +44,10 @@ type OfferAnswer struct {
SessionID *ICESessionID SessionID *ICESessionID
} }
func (o *OfferAnswer) hasICECredentials() bool {
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
}
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
@@ -59,6 +64,10 @@ type Handshaker struct {
relayListener *AsyncOfferListener relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer) iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
remoteICESupported atomic.Bool
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
@@ -66,7 +75,7 @@ type Handshaker struct {
} }
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
return &Handshaker{ h := &Handshaker{
log: log, log: log,
config: config, config: config,
signaler: signaler, signaler: signaler,
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
remoteOffersCh: make(chan OfferAnswer), remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer),
} }
// assume remote supports ICE until we learn otherwise from received offers
h.remoteICESupported.Store(ice != nil)
return h
}
func (h *Handshaker) RemoteICESupported() bool {
return h.remoteICESupported.Load()
} }
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
for { for {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts // Record signaling received for reconnection attempts
if h.metricsStages != nil { if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived() h.metricsStages.RecordSignalingReceived()
} }
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil { if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
} }
if h.iceListener != nil { if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer) h.iceListener(&remoteOfferAnswer)
} }
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue continue
} }
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts // Record signaling received for reconnection attempts
if h.metricsStages != nil { if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived() h.metricsStages.RecordSignalingReceived()
} }
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil { if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
} }
if h.iceListener != nil { if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer) h.iceListener(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
} }
func (h *Handshaker) buildOfferAnswer() OfferAnswer { func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{ answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort, WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(), Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr, RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid, }
if h.ice != nil && h.RemoteICESupported() {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer.IceCredentials = IceCredentials{uFrag, pwd}
answer.SessionID = &sid
} }
if addr, err := h.relay.RelayInstanceAddress(); err == nil { if addr, err := h.relay.RelayInstanceAddress(); err == nil {
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
return answer return answer
} }
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
hasICE := offer.hasICECredentials()
prev := h.remoteICESupported.Swap(hasICE)
if prev != hasICE {
if hasICE {
h.log.Infof("remote peer started sending ICE credentials")
} else {
h.log.Infof("remote peer stopped sending ICE credentials")
if h.ice != nil {
h.ice.Close()
}
}
}
}

View File

@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer // SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes() var sessionIDBytes []byte
if err != nil { if offerAnswer.SessionID != nil {
log.Warnf("failed to get session ID bytes: %v", err) var err error
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
} }
msg, err := signal.MarshalCredential( msg, err := signal.MarshalCredential(
s.wgPrivateKey, s.wgPrivateKey,

View File

@@ -0,0 +1,10 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }

View File

@@ -0,0 +1,241 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
if err := r.addScopedDefault(unspec, nexthop); err != nil {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks" "github.com/netbirdio/netbird/client/net/hooks"
) )
@@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
@@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
} }
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting() if nbnet.AdvancedRouting() {
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
if err != nil { if err != nil {
if !errors.Is(err, ErrRoutingIsSeparate) { log.Errorf("Failed to get routes: %v", err)
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{} return false, netip.Prefix{}
} }

View File

@@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil return []netip.Prefix{}, nil
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM. // GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil return []DetailedRoute{}, nil

View File

@@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool { func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,10 +48,6 @@ func EnableIPForwarding() error {
return nil return nil
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms) // GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) { func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS) log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,6 +25,9 @@ import (
const ( const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
) )
var routeProtoFlag int var routeProtoFlag int
@@ -41,26 +44,42 @@ func init() {
} }
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. // FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error { func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB() rib, err := retryFetchRIB()
if err != nil { if err != nil {
return fmt.Errorf("fetch routing table: %w", err) return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
} }
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil { if err != nil {
return fmt.Errorf("parse routing table: %w", err) return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
} }
var merr *multierror.Error
flushedCount := 0 flushedCount := 0
for _, msg := range msgs { for _, msg := range msgs {
@@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix) return fmt.Errorf("invalid prefix: %s", prefix)
} }
expBackOff := backoff.NewExponentialBackOff() msg, err := r.buildRouteMessage(action, prefix, nexthop)
expBackOff.InitialInterval = 50 * time.Millisecond if err != nil {
expBackOff.MaxInterval = 500 * time.Millisecond return fmt.Errorf("build route message: %w", err)
expBackOff.MaxElapsedTime = 1 * time.Second }
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil { if err := r.writeRouteMessage(msg, routeBudget); err != nil {
a := "add" a := "add"
if action == unix.RTM_DELETE { if action == unix.RTM_DELETE {
a = "remove" a = "remove"
@@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil return nil
} }
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error { // writeRouteMessage sends a route message over AF_ROUTE and waits for the
operation := func() error { // kernel's matching reply, retrying transient failures until budget elapses.
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) // Callers do not need to manage sockets or seq numbers themselves.
if err != nil { func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
return fmt.Errorf("open routing socket: %w", err) expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
} }
defer func() { }()
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err) tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
continue
} }
}() return backoff.Permanent(fmt.Errorf("read: %w", err))
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
} }
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
msgBytes, err := msg.Marshal() continue
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
} }
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
if _, err = unix.Write(fd, msgBytes); err != nil { // Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) { // uniquely identifies our own reply among broadcast traffic.
return fmt.Errorf("write: %w", err) if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
} continue
return backoff.Permanent(fmt.Errorf("write: %w", err))
} }
if hdr.Errno != 0 {
respBuf := make([]byte, 2048) return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
} }
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil return nil
} }
return operation
} }
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action, Type: action,
Flags: unix.RTF_UP | routeProtoFlag, Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(), Seq: r.getSeq(),
} }
@@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil return msg, nil
} }
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr). // addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) { func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() { if addr.Is4() {

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows //go:build !linux && !windows && !darwin
package net package net

View File

@@ -1,24 +0,0 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build windows //go:build (darwin && !ios) || windows
package net package net
@@ -24,17 +24,22 @@ func Init() {
} }
func checkAdvancedRoutingSupport() bool { func checkAdvancedRoutingSupport() bool {
var err error legacyRouting := false
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" { if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val) parsed, err := strconv.ParseBool(val)
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
} else {
legacyRouting = parsed
} }
} }
if legacyRouting || netstack.IsEnabled() { if legacyRouting {
log.Info("advanced routing has been requested to be disabled") log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
return false return false
} }

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !android //go:build !linux && !windows && !darwin
package net package net

25
client/net/env_mobile.go Normal file
View File

@@ -0,0 +1,25 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,5 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows //go:build !linux && !windows && !darwin
package net package net

160
client/net/net_darwin.go Normal file
View File

@@ -0,0 +1,160 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@@ -138,10 +137,8 @@ func restoreResidualState(ctx context.Context, statePath string) error {
} }
// clean up any remaining routes independently of the state file // clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() { if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)

View File

@@ -187,24 +187,23 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
return "", fmt.Errorf("get NetBird executable path: %w", err) return "", fmt.Errorf("get NetBird executable path: %w", err)
} }
hostLine := strings.Join(deduplicatedPatterns, " ") hostList := strings.Join(deduplicatedPatterns, ",")
config := fmt.Sprintf("Host %s\n", hostLine) config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PreferredAuthentications password,publickey,keyboard-interactive\n" config += " PasswordAuthentication yes\n"
config += " PasswordAuthentication yes\n" config += " PubkeyAuthentication yes\n"
config += " PubkeyAuthentication yes\n" config += " BatchMode no\n"
config += " BatchMode no\n" config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) config += " StrictHostKeyChecking no\n"
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n" config += " UserKnownHostsFile NUL\n"
} else { } else {
config += " UserKnownHostsFile /dev/null\n" config += " UserKnownHostsFile /dev/null\n"
} }
config += " CheckHostIP no\n" config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n" config += " LogLevel ERROR\n\n"
return config, nil return config, nil
} }

View File

@@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) {
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
} }
func TestManager_MatchHostFormat(t *testing.T) {
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
peers := []PeerSSHInfo{
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
"should use Match host with comma-separated patterns")
}
func TestManager_ForcedSSHConfig(t *testing.T) { func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable // Set force environment variable
t.Setenv(EnvForceSSHConfig, "true") t.Setenv(EnvForceSSHConfig, "true")

View File

@@ -2,7 +2,6 @@ package system
import ( import (
"context" "context"
"net"
"net/netip" "net/netip"
"strings" "strings"
@@ -145,59 +144,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v return v
} }
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks. // GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks)) log.Debugf("gathering system information with checks: %d", len(checks))

View File

@@ -2,6 +2,8 @@ package system
import ( import (
"context" "context"
"net"
"net/netip"
"runtime" "runtime"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -42,6 +44,66 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
// networkAddresses returns the list of network addresses on iOS.
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
// check that other platforms use and filter out link-local addresses as they
// are not useful for posture checks.
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: ""}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path. // checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) { func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil return []File{}, nil

View File

@@ -0,0 +1,66 @@
//go:build !ios
package system
import (
"net"
"net/netip"
)
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
mac := iface.HardwareAddr.String()
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address, mac)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: mac}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}

View File

@@ -119,6 +119,8 @@ server:
# Reverse proxy settings (optional) # Reverse proxy settings (optional)
# reverseProxy: # reverseProxy:
# trustedHTTPProxies: [] # trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
# trustedHTTPProxiesCount: 0 # trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
# trustedPeers: [] # trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.

View File

@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err) require.NoError(t, err)
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
// which is when Receive has actually returned. Without this, the Receive goroutine
// can outlive the test and call t.Logf after teardown, panicking.
receiveDone := make(chan struct{})
t.Cleanup(func() {
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Error("Receive goroutine did not exit after Close")
}
})
t.Cleanup(func() { t.Cleanup(func() {
err := client.Close() err := client.Close()
assert.NoError(t, err, "failed to close flow") assert.NoError(t, err, "failed to close flow")
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
receivedAfterReconnect := make(chan struct{}) receivedAfterReconnect := make(chan struct{})
go func() { go func() {
defer close(receiveDone)
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 { if msg.IsInitiator || len(msg.EventId) == 0 {
return nil return nil

34
go.mod
View File

@@ -17,12 +17,12 @@ require (
github.com/spf13/cobra v1.10.1 github.com/spf13/cobra v1.10.1
github.com/spf13/pflag v1.0.9 github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1 github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.48.0 golang.org/x/crypto v0.49.0
golang.org/x/sys v0.42.0 golang.org/x/sys v0.42.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.79.3 google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11 google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1
) )
@@ -115,13 +115,13 @@ require (
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
golang.org/x/mod v0.32.0 golang.org/x/mod v0.33.0
golang.org/x/net v0.51.0 golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.34.0 golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.19.0 golang.org/x/sync v0.20.0
golang.org/x/term v0.40.0 golang.org/x/term v0.41.0
golang.org/x/time v0.14.0 golang.org/x/time v0.15.0
google.golang.org/api v0.257.0 google.golang.org/api v0.276.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
@@ -131,7 +131,7 @@ require (
) )
require ( require (
cloud.google.com/go/auth v0.17.0 // indirect cloud.google.com/go/auth v0.20.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect
dario.cat/mergo v1.0.1 // indirect dario.cat/mergo v1.0.1 // indirect
@@ -210,8 +210,8 @@ require (
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.15.0 // indirect github.com/googleapis/gax-go/v2 v2.21.0 // indirect
github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/handlers v1.5.2 // indirect
github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/hack-pad/go-indexeddb v0.3.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect github.com/hack-pad/safejs v0.1.0 // indirect
@@ -295,16 +295,16 @@ require (
github.com/zeebo/blake3 v0.2.3 // indirect github.com/zeebo/blake3 v0.2.3 // indirect
go.mongodb.org/mongo-driver v1.17.9 // indirect go.mongodb.org/mongo-driver v1.17.9 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
go.opentelemetry.io/otel/sdk v1.43.0 // indirect go.opentelemetry.io/otel/sdk v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/image v0.33.0 // indirect golang.org/x/image v0.33.0 // indirect
golang.org/x/text v0.34.0 // indirect golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.41.0 // indirect golang.org/x/tools v0.42.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
@@ -323,3 +323,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

80
go.sum
View File

@@ -1,5 +1,5 @@
cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4= cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ= cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
@@ -285,10 +285,10 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
@@ -400,8 +400,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
@@ -449,6 +447,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
@@ -664,8 +664,8 @@ go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
@@ -707,8 +707,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
@@ -725,8 +725,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@@ -745,11 +745,11 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -761,8 +761,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -811,8 +811,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -824,10 +824,10 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
@@ -839,8 +839,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -851,19 +851,19 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA= google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY=
google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4= google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=

View File

@@ -182,6 +182,23 @@ read_enable_proxy() {
return 0 return 0
} }
read_enable_crowdsec() {
echo "" > /dev/stderr
echo "Do you want to enable CrowdSec IP reputation blocking?" > /dev/stderr
echo "CrowdSec checks client IPs against a community threat intelligence database" > /dev/stderr
echo "and blocks known malicious sources before they reach your services." > /dev/stderr
echo "A local CrowdSec LAPI container will be added to your deployment." > /dev/stderr
echo -n "Enable CrowdSec? [y/N]: " > /dev/stderr
read -r CHOICE < /dev/tty
if [[ "$CHOICE" =~ ^[Yy]$ ]]; then
echo "true"
else
echo "false"
fi
return 0
}
read_traefik_acme_email() { read_traefik_acme_email() {
echo "" > /dev/stderr echo "" > /dev/stderr
echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr
@@ -297,6 +314,10 @@ initialize_default_values() {
# NetBird Proxy configuration # NetBird Proxy configuration
ENABLE_PROXY="false" ENABLE_PROXY="false"
PROXY_TOKEN="" PROXY_TOKEN=""
# CrowdSec configuration
ENABLE_CROWDSEC="false"
CROWDSEC_BOUNCER_KEY=""
return 0 return 0
} }
@@ -325,6 +346,9 @@ configure_reverse_proxy() {
if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
ENABLE_PROXY=$(read_enable_proxy) ENABLE_PROXY=$(read_enable_proxy)
if [[ "$ENABLE_PROXY" == "true" ]]; then
ENABLE_CROWDSEC=$(read_enable_crowdsec)
fi
fi fi
# Handle external Traefik-specific prompts (option 1) # Handle external Traefik-specific prompts (option 1)
@@ -354,7 +378,7 @@ check_existing_installation() {
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
echo "You can use the following commands:" echo "You can use the following commands:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt && rm -rf crowdsec/"
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
exit 1 exit 1
fi fi
@@ -375,6 +399,9 @@ generate_configuration_files() {
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
# TCP ServersTransport for PROXY protocol v2 to the proxy backend # TCP ServersTransport for PROXY protocol v2 to the proxy backend
render_traefik_dynamic > traefik-dynamic.yaml render_traefik_dynamic > traefik-dynamic.yaml
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
mkdir -p crowdsec
fi
fi fi
;; ;;
1) 1)
@@ -417,8 +444,12 @@ start_services_and_show_instructions() {
if [[ "$ENABLE_PROXY" == "true" ]]; then if [[ "$ENABLE_PROXY" == "true" ]]; then
# Phase 1: Start core services (without proxy) # Phase 1: Start core services (without proxy)
local core_services="traefik dashboard netbird-server"
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
core_services="$core_services crowdsec"
fi
echo "Starting core services..." echo "Starting core services..."
$DOCKER_COMPOSE_COMMAND up -d traefik dashboard netbird-server $DOCKER_COMPOSE_COMMAND up -d $core_services
sleep 3 sleep 3
wait_management_proxy traefik wait_management_proxy traefik
@@ -438,7 +469,33 @@ start_services_and_show_instructions() {
echo "Proxy token created successfully." echo "Proxy token created successfully."
# Generate proxy.env with the token if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
echo "Registering CrowdSec bouncer..."
local cs_retries=0
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
cs_retries=$((cs_retries + 1))
if [[ $cs_retries -ge 30 ]]; then
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
echo "You can register a bouncer manually later with:" > /dev/stderr
echo " docker exec netbird-crowdsec cscli bouncers add netbird-proxy -o raw" > /dev/stderr
ENABLE_CROWDSEC="false"
break
fi
sleep 2
done
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
CROWDSEC_BOUNCER_KEY=$($DOCKER_COMPOSE_COMMAND exec -T crowdsec \
cscli bouncers add netbird-proxy -o raw 2>/dev/null)
if [[ -z "$CROWDSEC_BOUNCER_KEY" ]]; then
echo "WARNING: Failed to create CrowdSec bouncer key. Skipping CrowdSec setup." > /dev/stderr
ENABLE_CROWDSEC="false"
else
echo "CrowdSec bouncer registered."
fi
fi
fi
render_proxy_env > proxy.env render_proxy_env > proxy.env
# Start proxy service # Start proxy service
@@ -525,11 +582,25 @@ render_docker_compose_traefik_builtin() {
# Generate proxy service section and Traefik dynamic config if enabled # Generate proxy service section and Traefik dynamic config if enabled
local proxy_service="" local proxy_service=""
local proxy_volumes="" local proxy_volumes=""
local crowdsec_service=""
local crowdsec_volumes=""
local traefik_file_provider="" local traefik_file_provider=""
local traefik_dynamic_volume="" local traefik_dynamic_volume=""
if [[ "$ENABLE_PROXY" == "true" ]]; then if [[ "$ENABLE_PROXY" == "true" ]]; then
traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"' traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"'
traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro" traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro"
local proxy_depends="
netbird-server:
condition: service_started"
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
proxy_depends="
netbird-server:
condition: service_started
crowdsec:
condition: service_healthy"
fi
proxy_service=" proxy_service="
# NetBird Proxy - exposes internal resources to the internet # NetBird Proxy - exposes internal resources to the internet
proxy: proxy:
@@ -539,8 +610,7 @@ render_docker_compose_traefik_builtin() {
- 51820:51820/udp - 51820:51820/udp
restart: unless-stopped restart: unless-stopped
networks: [netbird] networks: [netbird]
depends_on: depends_on:${proxy_depends}
- netbird-server
env_file: env_file:
- ./proxy.env - ./proxy.env
volumes: volumes:
@@ -563,6 +633,35 @@ render_docker_compose_traefik_builtin() {
" "
proxy_volumes=" proxy_volumes="
netbird_proxy_certs:" netbird_proxy_certs:"
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
crowdsec_service="
crowdsec:
image: crowdsecurity/crowdsec:v1.7.7
container_name: netbird-crowdsec
restart: unless-stopped
networks: [netbird]
environment:
COLLECTIONS: crowdsecurity/linux
volumes:
- ./crowdsec:/etc/crowdsec
- crowdsec_db:/var/lib/crowdsec/data
healthcheck:
test: ["CMD", "cscli", "lapi", "status"]
interval: 10s
timeout: 5s
retries: 15
labels:
- traefik.enable=false
logging:
driver: \"json-file\"
options:
max-size: \"500m\"
max-file: \"2\"
"
crowdsec_volumes="
crowdsec_db:"
fi
fi fi
cat <<EOF cat <<EOF
@@ -675,10 +774,10 @@ $traefik_dynamic_volume
options: options:
max-size: "500m" max-size: "500m"
max-file: "2" max-file: "2"
${proxy_service} ${proxy_service}${crowdsec_service}
volumes: volumes:
netbird_data: netbird_data:
netbird_traefik_letsencrypt:${proxy_volumes} netbird_traefik_letsencrypt:${proxy_volumes}${crowdsec_volumes}
networks: networks:
netbird: netbird:
@@ -783,6 +882,14 @@ NB_PROXY_PROXY_PROTOCOL=true
# Trust Traefik's IP for PROXY protocol headers # Trust Traefik's IP for PROXY protocol headers
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
EOF EOF
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
cat <<EOF
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
NB_PROXY_CROWDSEC_API_KEY=$CROWDSEC_BOUNCER_KEY
EOF
fi
return 0 return 0
} }
@@ -1172,6 +1279,17 @@ print_builtin_traefik_instructions() {
echo "" echo ""
echo " *.$NETBIRD_DOMAIN CNAME $NETBIRD_DOMAIN" echo " *.$NETBIRD_DOMAIN CNAME $NETBIRD_DOMAIN"
echo "" echo ""
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
echo "CrowdSec IP Reputation:"
echo " CrowdSec LAPI is running and connected to the community blocklist."
echo " The proxy will automatically check client IPs against known threats."
echo " Enable CrowdSec per-service in the dashboard under Access Control."
echo ""
echo " To enroll in CrowdSec Console (optional, for dashboard and premium blocklists):"
echo " docker exec netbird-crowdsec cscli console enroll <your-enrollment-key>"
echo " Get your enrollment key at: https://app.crowdsec.net"
echo ""
fi
fi fi
return 0 return 0
} }

View File

@@ -7,7 +7,6 @@ import (
"os" "os"
"slices" "slices"
"strconv" "strconv"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -16,11 +15,9 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -58,13 +55,6 @@ type Controller struct {
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
} }
type bufferUpdate struct { type bufferUpdate struct {
@@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
log.Fatal(fmt.Errorf("error creating metrics: %w", err)) log.Fatal(fmt.Errorf("error creating metrics: %w", err))
} }
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
compactedNetworkMap := true
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
if err != nil && len(compactedEnv) > 0 {
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
}
if err == nil && !parsedCompactedNmap {
log.WithContext(ctx).Info("disabling compacted mode")
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{ return &Controller{
repo: newRepository(store), repo: newRepository(store),
metrics: nMetrics, metrics: nMetrics,
@@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
proxyController: proxyController, proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager, EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
} }
} }
@@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var ( account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account *types.Account if err != nil {
err error return fmt.Errorf("failed to get account: %v", err)
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
} }
globalStart := time.Now() globalStart := time.Now()
@@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers() groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start)) c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now() start = time.Now()
var remotePeerNetworkMap *types.NetworkMap remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
// UpdatePeers updates all peers that belong to an account. // UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID) return c.sendUpdateAccountPeers(ctx, accountID)
} }
@@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err return err
} }
var remotePeerNetworkMap *types.NetworkMap remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, emptyMap, nil, 0, nil return peer, emptyMap, nil, 0, nil
} }
var ( account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account *types.Account if err != nil {
err error return nil, nil, nil, 0, err
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
} }
account.InjectProxyPolicies(ctx) account.InjectProxyPolicies(ctx)
@@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, nil, 0, err return nil, nil, nil, 0, err
} }
var networkMap *types.NetworkMap resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if c.experimentalNetworkMap(accountID) { groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, networkMap, postureChecks, dnsFwdPort, nil return peer, networkMap, postureChecks, dnsFwdPort, nil
} }
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
c.enrichAccountFromHolder(account)
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain // GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string { func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil { if settings == nil {
@@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
} }
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) err := c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
} }
@@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID) return c.bufferSendUpdateAccountPeers(ctx, accountID)
} }
@@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
MessageType: network_map.MessageTypeNetworkMap, MessageType: network_map.MessageTypeNetworkMap,
}) })
c.peersUpdateManager.CloseChannel(ctx, peerID) c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
} }
return c.bufferSendUpdateAccountPeers(ctx, accountID) return c.bufferSendUpdateAccountPeers(ctx, accountID)
@@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err return nil, err
} }
var networkMap *types.NetworkMap account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
if c.experimentalNetworkMap(peer.AccountID) { routers := account.GetResourceRoutersMap()
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil) groupIDToUserIDs := account.GetActiveGroupUsers()
} else { networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {

View File

@@ -12,9 +12,6 @@ import (
) )
const ( const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0" DnsForwarderPortMinVersion = "v0.59.0"

View File

@@ -30,6 +30,7 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
}) })
} }
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()
limiter := middleware.NewAPIRateLimiter(cfg)
limiter.SetEnabled(enabled)
return limiter
})
}
func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server { return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }
@@ -1171,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
} }
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
testAccountManager_NetworkUpdates_SaveGroup(t) testAccountManager_NetworkUpdates_SaveGroup(t)
} }
@@ -1231,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t) testAccountManager_NetworkUpdates_DeletePolicy(t)
} }
@@ -1274,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_SavePolicy(t) testAccountManager_NetworkUpdates_SavePolicy(t)
} }
@@ -1332,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePeer(t) testAccountManager_NetworkUpdates_DeletePeer(t)
} }
@@ -1397,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t) testAccountManager_NetworkUpdates_DeleteGroup(t)
} }
@@ -1633,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-2"))
} }
func TestAccount_GetRoutesToSync(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
if err != nil {
t.Fatal(err)
}
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
},
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[route.ID]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-2": {
ID: "route-2",
Network: prefix2,
NetID: "network-2",
Description: "network-2",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-3": {
ID: "route-3",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
},
}
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
func TestAccount_Copy(t *testing.T) { func TestAccount_Copy(t *testing.T) {
account := &types.Account{ account := &types.Account{
Id: "account1", Id: "account1",
@@ -1824,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) {
AccountID: "account1", AccountID: "account1",
}, },
}, },
NetworkMapCache: &types.NetworkMapBuilder{},
} }
account.InitOnce()
err := hasNilField(account) err := hasNilField(account)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -2311,6 +2215,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
} }
} }
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
ctx := context.Background()
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanUp)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
// Verify the already-expired peer is excluded at the store level
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
require.NoError(t, err)
for _, peer := range peers {
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
}
// Only the non-expired peer with expiration enabled should be returned
require.Len(t, peers, 1)
assert.Equal(t, "notexpired01", peers[0].ID)
}
func TestAccount_GetInactivePeers(t *testing.T) { func TestAccount_GetInactivePeers(t *testing.T) {
type test struct { type test struct {
name string name string
@@ -3230,6 +3157,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
return manager, updateManager, account, peer1, peer2, peer3 return manager, updateManager, account, peer1, peer2, peer3
} }
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
// wrappers wait for an expected update message. Sized for slow CI runners
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
// seconds. Only runs down on failure; passing tests return immediately
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper() t.Helper()
select { select {
@@ -3248,7 +3182,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
if msg == nil { if msg == nil {
t.Errorf("Received nil update message, expected valid message") t.Errorf("Received nil update message, expected valid message")
} }
case <-time.After(500 * time.Millisecond): case <-time.After(peerUpdateTimeout):
t.Error("Timed out waiting for update message") t.Error("Timed out waiting for update message")
} }
} }

View File

@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -6,8 +6,6 @@ import (
"net/http" "net/http"
"net/netip" "net/netip"
"os" "os"
"strconv"
"time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/rs/cors" "github.com/rs/cors"
@@ -66,15 +64,12 @@ import (
) )
const ( const (
apiPrefix = "/api" apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
setupCreatePATEnabledKey = "NB_SETUP_PAT_ENABLED" setupCreatePATEnabledKey = "NB_SETUP_PAT_ENABLED"
) )
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints // Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil { if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -95,34 +90,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err) return nil, fmt.Errorf("failed to add bypass path: %w", err)
} }
var rateLimitingConfig *middleware.RateLimiterConfig if rateLimiter == nil {
if os.Getenv(rateLimitingEnabledKey) == "true" { log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rpm := 6 rateLimiter = middleware.NewAPIRateLimiter(nil)
if v := os.Getenv(rateLimitingRPMKey); v != "" { rateLimiter.SetEnabled(false)
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
} }
authMiddleware := middleware.NewAuthMiddleware( authMiddleware := middleware.NewAuthMiddleware(
@@ -130,7 +101,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetAccountIDFromUserAuth, accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups, accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth, accountManager.GetUserFromUserAuth,
rateLimitingConfig, rateLimiter,
appMetrics.GetMeter(), appMetrics.GetMeter(),
) )

View File

@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
} }

View File

@@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth" serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -42,14 +43,9 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc, ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc, syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc, getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig, rateLimiter *APIRateLimiter,
meter metric.Meter, meter metric.Meter,
) *AuthMiddleware { ) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker var patUsageTracker *PATUsageTracker
if meter != nil { if meter != nil {
var err error var err error
@@ -87,17 +83,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType { switch authType {
case "bearer": case "bearer":
request, err := m.checkJWTFromRequest(r, authHeader) if err := m.checkJWTFromRequest(r, authHeader); err != nil {
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return return
} }
h.ServeHTTP(w, r)
h.ServeHTTP(w, request)
case "token": case "token":
request, err := m.checkPATFromRequest(r, authHeader) if err := m.checkPATFromRequest(r, authHeader); err != nil {
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized // Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok { if _, ok := status.FromError(err); !ok {
@@ -106,7 +99,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
h.ServeHTTP(w, request) h.ServeHTTP(w, r)
default: default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return return
@@ -115,19 +108,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
} }
// CheckJWTFromRequest checks if the JWT is valid // CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromJWTRequest(authHeaderParts) token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error // If an error occurs, call the error handler and return an error
if err != nil { if err != nil {
return r, fmt.Errorf("error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
ctx := r.Context() ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
if err != nil { if err != nil {
return r, err return err
} }
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
@@ -143,7 +136,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
// we need to call this method because if user is new, we will automatically add it to existing or create a new account // we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth) accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil { if err != nil {
return r, err return err
} }
if userAuth.AccountId != accountId { if userAuth.AccountId != accountId {
@@ -153,7 +146,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil { if err != nil {
return r, err return err
} }
err = m.syncUserJWTGroups(ctx, userAuth) err = m.syncUserJWTGroups(ctx, userAuth)
@@ -164,41 +157,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
_, err = m.getUserFromUserAuth(ctx, userAuth) _, err = m.getUserFromUserAuth(ctx, userAuth)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return r, err return err
} }
return nbcontext.SetUserAuthInRequest(r, userAuth), nil // propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
} }
// CheckPATFromRequest checks if the PAT is valid // CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromPATRequest(authHeaderParts) token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil { if err != nil {
return r, fmt.Errorf("error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
if m.patUsageTracker != nil { if m.patUsageTracker != nil {
m.patUsageTracker.IncrementUsage(token) m.patUsageTracker.IncrementUsage(token)
} }
if m.rateLimiter != nil && !isTerraformRequest(r) { if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
if !m.rateLimiter.Allow(token) { return status.Errorf(status.TooManyRequests, "too many requests")
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
} }
ctx := r.Context() ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil { if err != nil {
return r, fmt.Errorf("invalid Token: %w", err) return fmt.Errorf("invalid Token: %w", err)
} }
if time.Now().After(pat.GetExpirationDate()) { if time.Now().After(pat.GetExpirationDate()) {
return r, fmt.Errorf("token expired") return fmt.Errorf("token expired")
} }
err = m.authManager.MarkPATUsed(ctx, pat.ID) err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil { if err != nil {
return r, err return err
} }
userAuth := auth.UserAuth{ userAuth := auth.UserAuth{
@@ -216,7 +209,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
} }
} }
return nbcontext.SetUserAuthInRequest(r, userAuth), nil // propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
} }
func isTerraformRequest(r *http.Request) bool { func isTerraformRequest(r *http.Request) bool {

View File

@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT, GetPATInfoFunc: mockGetAccountInfoFromPAT,
} }
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware( authMiddleware := NewAuthMiddleware(
mockAuth, mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
nil, disabledLimiter,
nil, nil,
) )
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
rateLimitConfig, NewAPIRateLimiter(rateLimitConfig),
nil, nil,
) )
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT, GetPATInfoFunc: mockGetAccountInfoFromPAT,
} }
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware( authMiddleware := NewAuthMiddleware(
mockAuth, mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil return &types.User{}, nil
}, },
nil, disabledLimiter,
nil, nil,
) )

View File

@@ -4,14 +4,27 @@ import (
"context" "context"
"net" "net"
"net/http" "net/http"
"os"
"strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
const (
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
defaultAPIRPM = 6
defaultAPIBurst = 500
)
// RateLimiterConfig holds configuration for the API rate limiter // RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct { type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished // RequestsPerMinute defines the rate at which tokens are replenished
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
} }
} }
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
rpm := defaultAPIRPM
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
} else {
rpm = value
}
}
if rpm <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
rpm = defaultAPIRPM
}
burst := defaultAPIBurst
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
} else {
burst = value
}
}
if burst <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
burst = defaultAPIBurst
}
return &RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}, os.Getenv(RateLimitingEnabledEnv) == "true"
}
// limiterEntry holds a rate limiter and its last access time // limiterEntry holds a rate limiter and its last access time
type limiterEntry struct { type limiterEntry struct {
limiter *rate.Limiter limiter *rate.Limiter
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
limiters map[string]*limiterEntry limiters map[string]*limiterEntry
mu sync.RWMutex mu sync.RWMutex
stopChan chan struct{} stopChan chan struct{}
enabled atomic.Bool
} }
// NewAPIRateLimiter creates a new API rate limiter with the given configuration // NewAPIRateLimiter creates a new API rate limiter with the given configuration
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
limiters: make(map[string]*limiterEntry), limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
} }
rl.enabled.Store(true)
go rl.cleanupLoop() go rl.cleanupLoop()
return rl return rl
} }
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
rl.enabled.Store(enabled)
}
func (rl *APIRateLimiter) Enabled() bool {
return rl.enabled.Load()
}
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
if config == nil {
return
}
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
return
}
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
newBurst := config.Burst
rl.mu.Lock()
rl.config.RequestsPerMinute = config.RequestsPerMinute
rl.config.Burst = newBurst
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
for _, entry := range rl.limiters {
snapshot = append(snapshot, entry.limiter)
}
rl.mu.Unlock()
for _, l := range snapshot {
l.SetLimit(newRPS)
l.SetBurst(newBurst)
}
}
// Allow checks if a request for the given key (token) is allowed // Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool { func (rl *APIRateLimiter) Allow(key string) bool {
if !rl.enabled.Load() {
return true
}
limiter := rl.getLimiter(key) limiter := rl.getLimiter(key)
return limiter.Allow() return limiter.Allow()
} }
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
// Wait blocks until the rate limiter allows another request for the given key // Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled // Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
if !rl.enabled.Load() {
return nil
}
limiter := rl.getLimiter(key) limiter := rl.getLimiter(key)
return limiter.Wait(ctx) return limiter.Wait(ctx)
} }
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
// Returns 429 Too Many Requests if the rate limit is exceeded. // Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.enabled.Load() {
next.ServeHTTP(w, r)
return
}
clientIP := getClientIP(r) clientIP := getClientIP(r)
if !rl.Allow(clientIP) { if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)

View File

@@ -1,8 +1,10 @@
package middleware package middleware
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
// Should be allowed again // Should be allowed again
assert.True(t, rl.Allow("test-key")) assert.True(t, rl.Allow("test-key"))
} }
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("key"))
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
rl.SetEnabled(false)
assert.False(t, rl.Enabled())
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
}
rl.SetEnabled(true)
assert.True(t, rl.Enabled())
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k1"))
assert.True(t, rl.Allow("k1"))
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
// New burst applies to existing keys in place; bucket refills up to new burst over time,
// but importantly newly-added keys use the updated config immediately.
assert.True(t, rl.Allow("k2"))
for i := 0; i < 9; i++ {
assert.True(t, rl.Allow("k2"))
}
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
rl.UpdateConfig(nil) // must not panic or zero the config
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
}
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.Reset("k")
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 600,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
var wg sync.WaitGroup
stop := make(chan struct{})
for i := 0; i < 8; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("k%d", id)
for {
select {
case <-stop:
return
default:
rl.Allow(key)
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 200; i++ {
select {
case <-stop:
return
default:
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: float64(30 + (i % 90)),
Burst: 1 + (i % 20),
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
rl.SetEnabled(i%2 == 0)
}
}
}()
time.Sleep(100 * time.Millisecond)
close(stop)
wg.Wait()
}
func TestRateLimiterConfigFromEnv(t *testing.T) {
t.Setenv(RateLimitingEnabledEnv, "true")
t.Setenv(RateLimitingRPMEnv, "42")
t.Setenv(RateLimitingBurstEnv, "7")
cfg, enabled := RateLimiterConfigFromEnv()
assert.True(t, enabled)
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
assert.Equal(t, 7, cfg.Burst)
t.Setenv(RateLimitingEnabledEnv, "false")
_, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
t.Setenv(RateLimitingEnabledEnv, "")
t.Setenv(RateLimitingRPMEnv, "")
t.Setenv(RateLimitingBurstEnv, "")
cfg, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
assert.Equal(t, defaultAPIBurst, cfg.Burst)
t.Setenv(RateLimitingRPMEnv, "0")
t.Setenv(RateLimitingBurstEnv, "-5")
cfg, _ = RateLimiterConfigFromEnv()
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}

View File

@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }

View File

@@ -66,14 +66,14 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
} }
// Create a new Admin SDK Directory service client // Create a new Admin SDK Directory service client
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey) credentialsOption, err := getGoogleCredentialsOption(ctx, config.ServiceAccountKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
service, err := admin.NewService(context.Background(), service, err := admin.NewService(context.Background(),
option.WithScopes(admin.AdminDirectoryUserReadonlyScope), option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
option.WithCredentials(adminCredentials), credentialsOption,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -218,39 +218,32 @@ func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) e
return nil return nil
} }
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. // getGoogleCredentialsOption returns the google.golang.org/api option carrying
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // Google credentials derived from the provided serviceAccountKey.
// If that fails, it falls back to using the default Google credentials path. // It decodes the base64-encoded serviceAccountKey and uses it as the credentials JSON.
// It returns the retrieved credentials or an error if unsuccessful. // If the key is empty, it falls back to the default Google credentials path.
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) { func getGoogleCredentialsOption(ctx context.Context, serviceAccountKey string) (option.ClientOption, error) {
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key") log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode service account key: %w", err) return nil, fmt.Errorf("failed to decode service account key: %w", err)
} }
creds, err := google.CredentialsFromJSON( if len(decodeKey) > 0 {
context.Background(), return option.WithAuthCredentialsJSON(option.ServiceAccount, decodeKey), nil
decodeKey,
admin.AdminDirectoryUserReadonlyScope,
)
if err == nil {
// No need to fallback to the default Google credentials path
return creds, nil
} }
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) log.WithContext(ctx).Debug("no service account key provided, falling back to default google credentials location")
log.WithContext(ctx).Debug("falling back to default google credentials location")
creds, err = google.FindDefaultCredentials( creds, err := google.FindDefaultCredentials(
context.Background(), ctx,
admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryUserReadonlyScope,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
return creds, nil return option.WithCredentials(creds), nil
} }
// parseGoogleWorkspaceUser parse google user to UserData. // parseGoogleWorkspaceUser parse google user to UserData.

View File

@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
} }
// expired peers come separately. // expired peers come separately.
if len(networkMap.GetOfflinePeers()) != 1 { if len(networkMap.GetOfflinePeers()) != 2 {
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer") t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
} }
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=" expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="

View File

@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry { for _, peer := range peersWithExpiry {
if peer.Status.LoginExpired {
continue
}
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired { if expired {
peers = append(peers, peer) peers = append(peers, peer)

View File

@@ -179,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
testGetNetworkMapGeneral(t) testGetNetworkMapGeneral(t)
} }
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) { func testGetNetworkMapGeneral(t *testing.T) {
manager, _, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
@@ -1016,11 +1011,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
} }
} }
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
func TestUpdateAccountPeers(t *testing.T) { func TestUpdateAccountPeers(t *testing.T) {
testUpdateAccountPeers(t) testUpdateAccountPeers(t)
} }
@@ -1600,7 +1590,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
} }
func Test_LoginPeer(t *testing.T) { func Test_LoginPeer(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
} }
@@ -1907,7 +1896,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1929,7 +1918,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1994,7 +1983,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2012,7 +2001,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2058,7 +2047,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2076,7 +2065,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2113,7 +2102,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2131,7 +2120,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -5,6 +5,7 @@ import (
_ "embed" _ "embed"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
var isUpdate = policy.ID != "" var isUpdate = policy.ID != ""
var updateAccountPeers bool var updateAccountPeers bool
var action = activity.PolicyAdded var action = activity.PolicyAdded
var unchanged bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
if err != nil { if err != nil {
return err return err
} }
saveFunc := transaction.CreatePolicy
if isUpdate { if isUpdate {
action = activity.PolicyUpdated if policy.Equal(existingPolicy) {
saveFunc = transaction.SavePolicy logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
} unchanged = true
return nil
}
if err = saveFunc(ctx, policy); err != nil { action = activity.PolicyUpdated
return err
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
if err != nil {
return err
}
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
} else {
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}
} }
return transaction.IncrementNetworkSerial(ctx, accountID) return transaction.IncrementNetworkSerial(ctx, accountID)
@@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return nil, err return nil, err
} }
if unchanged {
return policy, nil
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err return err
} }
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil { if err != nil {
return err return err
} }
@@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
} }
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. // arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
if isUpdate { for _, rule := range policy.Rules {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
if err != nil {
return false, err
}
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil return true, nil
} }
} }
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil return true, nil
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
} }
// validatePolicy validates the policy and its rules. // validatePolicy validates the policy and its rules. For updates it returns
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { // the existing policy loaded from the store so callers can avoid a second read.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
var existingPolicy *types.Policy
if policy.ID != "" { if policy.ID != "" {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) var err error
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return err return nil, err
} }
// TODO: Refactor to support multiple rules per policy // TODO: Refactor to support multiple rules per policy
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
for _, rule := range policy.Rules { for _, rule := range policy.Rules {
if rule.ID != "" && !existingRuleIDs[rule.ID] { if rule.ID != "" && !existingRuleIDs[rule.ID] {
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
} }
} }
} else { } else {
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil { if err != nil {
return err return nil, err
} }
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil { if err != nil {
return err return nil, err
} }
for i, rule := range policy.Rules { for i, rule := range policy.Rules {
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
} }
return nil return existingPolicy, nil
} }
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.

View File

@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -2,10 +2,8 @@ package server
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/netip" "net/netip"
"sort"
"testing" "testing"
"time" "time"
@@ -1840,11 +1838,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
}, },
} }
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("check applied policies for the route", func(t *testing.T) { t.Run("check applied policies for the route", func(t *testing.T) {
route1 := account.Routes["route1"] route1 := account.Routes["route1"]
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
@@ -1858,116 +1851,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
assert.Len(t, policies, 0) assert.Len(t, policies, 0)
}) })
t.Run("check peer routes firewall rules", func(t *testing.T) {
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
assert.Len(t, routesFirewallRules, 4)
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "route1:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "route1:peerA",
},
}
additionalFirewallRule := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "tcp",
Port: 80,
RouteID: "route4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "all",
RouteID: "route4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 2)
for _, rule := range expectedRoutesFirewallRules {
rule.RouteID = "route1:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerE is a single routing peer for route 2 and route 3
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
assert.Len(t, routesFirewallRules, 3)
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: existingNetwork.String(),
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "route2",
},
{
SourceRanges: []string{"0.0.0.0/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
{
SourceRanges: []string{"::/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, routesFirewallRules, 0)
})
}
// orderList is a helper function to sort a list of strings
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
for _, rule := range ruleList {
sort.Strings(rule.SourceRanges)
}
return ruleList
} }
func TestRouteAccountPeersUpdate(t *testing.T) { func TestRouteAccountPeersUpdate(t *testing.T) {
@@ -2070,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
@@ -2107,7 +1990,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2127,7 +2010,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2145,7 +2028,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2185,7 +2068,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2225,7 +2108,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2665,11 +2548,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
}, },
} }
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("validate applied policies for different network resources", func(t *testing.T) { t.Run("validate applied policies for different network resources", func(t *testing.T) {
// Test case: Resource1 is directly applied to the policy (policyResource1) // Test case: Resource1 is directly applied to the policy (policyResource1)
policies := account.GetPoliciesForNetworkResource("resource1") policies := account.GetPoliciesForNetworkResource("resource1")
@@ -2693,127 +2571,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
policies = account.GetPoliciesForNetworkResource("resource6") policies = account.GetPoliciesForNetworkResource("resource6")
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
}) })
t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) {
resourcePoliciesMap := account.GetResourcePoliciesMap()
resourceRoutersMap := account.GetResourceRoutersMap()
_, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap)
firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 4)
assert.Len(t, sourcePeers, 5)
expectedFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "resource2:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "resource2:peerA",
},
}
additionalFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "tcp",
Port: 80,
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...)))
// peerD is also the routing peer for resource2
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 2)
for _, rule := range expectedFirewallRules {
rule.RouteID = "resource2:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
assert.Len(t, sourcePeers, 3)
// peerE is a single routing peer for resource1 and resource3
// PeerE should only receive rules for resource1 since resource3 has no applied policy
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 2)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: "10.10.10.0/24",
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "resource1:peerE",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
// peerC is part of distribution groups for resource2 but should not receive the firewall rules
firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, firewallRules, 0)
// peerL is the single routing peer for resource5
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 1)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.29.67/32"},
Action: "accept",
Destination: "10.12.12.1/32",
Protocol: "tcp",
Port: 8080,
RouteID: "resource5:peerL",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 0)
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 2)
})
} }

View File

@@ -1196,7 +1196,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
account.NameServerGroups[ns.ID] = &ns account.NameServerGroups[ns.ID] = &ns
} }
account.NameServerGroupsG = nil account.NameServerGroupsG = nil
account.InitOnce()
return &account, nil return &account, nil
} }
@@ -1635,7 +1634,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
if sExtraIntegratedValidatorGroups.Valid { if sExtraIntegratedValidatorGroups.Valid {
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
} }
account.InitOnce()
return &account, nil return &account, nil
} }
@@ -3310,7 +3308,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
result := tx. result := tx.
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
Find(&peers, accountIDCondition, accountID) Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)

View File

@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
{ {
name: "should retrieve peers for an existing account ID", name: "should retrieve peers for an existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 4, expectedCount: 5,
}, },
{ {
name: "should return no peers for a non-existing account ID", name: "should return no peers for a non-existing account ID",
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
name: "should filter peers by partial name", name: "should filter peers by partial name",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
nameFilter: "host", nameFilter: "host",
expectedCount: 3, expectedCount: 4,
}, },
{ {
name: "should filter peers by ip", name: "should filter peers by ip",
@@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
accountID string accountID string
expectedCount int expectedCount int
expectedPeerIDs []string
}{ }{
{ {
name: "should retrieve peers with expiration for an existing account ID", name: "should retrieve only non-expired peers with expiration enabled",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1, expectedCount: 1,
expectedPeerIDs: []string{"notexpired01"},
}, },
{ {
name: "should return no peers with expiration for a non-existing account ID", name: "should return no peers with expiration for a non-existing account ID",
@@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID) peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, peers, tt.expectedCount) require.Len(t, peers, tt.expectedCount)
for i, peer := range peers {
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
}
}) })
} }
} }
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
require.NoError(t, err)
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
for _, peer := range peers {
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
}
}
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanup) t.Cleanup(cleanup)
@@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
name: "should retrieve peers for another valid account ID and user ID", name: "should retrieve peers for another valid account ID and user ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
userID: "edafee4e-63fb-11ec-90d6-0242ac120003", userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
expectedCount: 2, expectedCount: 3,
}, },
{ {
name: "should return no peers for existing account ID with empty user ID", name: "should return no peers for existing account ID with empty user ID",

View File

@@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
} }
}) })
h.ServeHTTP(w, r.WithContext(ctx)) // Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
req := r.WithContext(ctx)
h.ServeHTTP(w, req)
close(handlerDone) close(handlerDone)
userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) ctx = req.Context()
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 { if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())

View File

@@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@@ -8,7 +8,6 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@@ -27,7 +26,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
@@ -110,16 +108,9 @@ type Account struct {
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
nmapInitOnce *sync.Once `gorm:"-"`
ReverseProxyFreeDomainNonce string ReverseProxyFreeDomainNonce string
} }
func (a *Account) InitOnce() {
a.nmapInitOnce = &sync.Once{}
}
// this class is used by gorm only // this class is used by gorm only
type PrimaryAccountInfo struct { type PrimaryAccountInfo struct {
IsDomainPrimaryAccount bool IsDomainPrimaryAccount bool
@@ -155,108 +146,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
o.SignupFormPending == onboarding.SignupFormPending o.SignupFormPending == onboarding.SignupFormPending
} }
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
for _, peer := range aclPeers {
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[string(r.GetHAUniqueID())]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
}
return filteredRoutes
}
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
for _, groupID := range r.Groups {
_, found := groupListMap[groupID]
if found {
filteredRoutes = append(filteredRoutes, r)
break
}
}
}
return filteredRoutes
}
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
// If the given is not a routing peer, then the lists are empty.
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
peer := a.GetPeer(peerID)
if peer == nil {
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok {
return
}
seenRoute[r.ID] = struct{}{}
if r.Enabled {
r.Peer = peer.Key
enabledRoutes = append(enabledRoutes, r)
return
}
disabledRoutes = append(disabledRoutes, r)
}
for _, r := range a.Routes {
for _, groupID := range r.PeerGroups {
group := a.GetGroup(groupID)
if group == nil {
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
continue
}
for _, id := range group.Peers {
if id != peerID {
continue
}
newPeerRoute := r.Copy()
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id)
break
}
}
if r.Peer == peerID {
takeRoute(r.Copy(), peerID)
}
}
return enabledRoutes, disabledRoutes
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix // GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
var routes []*route.Route var routes []*route.Route
@@ -276,106 +165,6 @@ func (a *Account) GetGroup(groupID string) *Group {
return a.Groups[groupID] return a.Groups[groupID]
} }
// GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
if _, ok := validatedPeersMap[peerID]; !ok {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
peerGroups := a.GetPeerGroups(peerID)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
var networkResourcesFirewallRules []*RouteFirewallRule
if isRouter {
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
}
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
})
}
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
zones = append(zones, filteredAccountZones...)
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
nm := &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: a.Network.Copy(),
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
AuthorizedUsers: authorizedUsers,
EnableSSH: enableSSH,
}
if metrics != nil {
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
if objectCount > 5000 {
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
}
}
return nm
}
func (a *Account) addNetworksRoutingPeers( func (a *Account) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route, networkResourcesRoutes []*route.Route,
peer *nbpeer.Peer, peer *nbpeer.Peer,
@@ -421,39 +210,6 @@ func (a *Account) addNetworksRoutingPeers(
return peersToConnect return peersToConnect
} }
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
groupList := account.GetPeerGroups(peerID)
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range account.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
}
return peerNSGroups
}
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peer.IP.Equal(ns.IP.AsSlice()) {
return true
}
}
return false
}
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
for _, peer := range account.Peers { for _, peer := range account.Peers {
label, err := GetPeerHostLabel(peer.Name, peerLabels) label, err := GetPeerHostLabel(peer.Name, peerLabels)
@@ -800,19 +556,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps return grps
} }
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.GetPeerGroups(peerID)
enabled := true
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
_, found := peerGroups[groupID]
if found {
enabled = false
break
}
}
return enabled
}
func (a *Account) GetPeerGroups(peerID string) LookupMap { func (a *Account) GetPeerGroups(peerID string) LookupMap {
groupList := make(LookupMap) groupList := make(LookupMap)
for groupID, group := range a.Groups { for groupID, group := range a.Groups {
@@ -941,8 +684,6 @@ func (a *Account) Copy() *Account {
NetworkResources: networkResources, NetworkResources: networkResources,
Services: services, Services: services,
Onboarding: a.Onboarding, Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
Domains: domains, Domains: domains,
} }
} }
@@ -1304,31 +1045,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
return nil return nil
} }
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
for _, route := range enabledRoutes {
// If no access control groups are specified, accept all traffic.
if len(route.AccessControlGroups) == 0 {
defaultPermit := getDefaultPermit(route)
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
continue
}
distributionPeers := a.getDistributionGroupsPeers(route)
for _, accessGroup := range route.AccessControlGroups {
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
return routesFirewallRules
}
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule var fwRules []*RouteFirewallRule
for _, policy := range policies { for _, policy := range policies {
@@ -1387,50 +1103,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
return distributionGroupPeers return distributionGroupPeers
} }
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range route.Groups {
group := a.Groups[id]
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule
sources := []string{"0.0.0.0/0"}
if route.Network.Addr().Is6() {
sources = []string{"::/0"}
}
rule := RouteFirewallRule{
SourceRanges: sources,
Action: string(PolicyTrafficActionAccept),
Destination: route.Network.String(),
Protocol: string(PolicyRuleProtocolALL),
Domains: route.Domains,
IsDynamic: route.IsDynamic(),
RouteID: route.ID,
}
rules = append(rules, &rule)
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
if route.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
rules = append(rules, &ruleV6)
}
return rules
}
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups // GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
// and returns a list of policies that have rules with destinations matching the specified groups. // and returns a list of policies that have rules with destinations matching the specified groups.
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
@@ -1508,65 +1180,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
return resourcePolicies return resourcePolicies
} }
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
allSourcePeers := make(map[string]struct{}, len(a.Peers))
for _, resource := range a.NetworkResources {
if !resource.Enabled {
continue
}
var addSourcePeers bool
networkRoutingPeers, exists := routers[resource.NetworkID]
if exists {
if router, ok := networkRoutingPeers[peerID]; ok {
isRoutingPeer, addSourcePeers = true, true
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
}
}
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
if addSourcePeers {
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
}
addedResourceRoute = true
}
if addedResourceRoute {
break
}
}
}
return isRoutingPeer, routes, allSourcePeers
}
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
}
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups { for _, groupID := range groups {
@@ -1658,22 +1271,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
return result return result
} }
// getNetworkResourcesRoutes convert the network resources list to routes list.
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
resourceAppliedPolicies := resourcePolicies[resource.ID]
var routes []*route.Route
// distribute the resource routes only if there is policy applied to it
if len(resourceAppliedPolicies) > 0 {
peer := a.GetPeer(peerId)
if peer != nil {
routes = append(routes, resource.ToRoute(peer, router))
}
}
return routes
}
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
routers := make(map[string]map[string]*routerTypes.NetworkRouter) routers := make(map[string]map[string]*routerTypes.NetworkRouter)

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -19,7 +17,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -451,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
require.Len(t, result, 0) require.Len(t, result, 0)
} }
const (
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
Enabled: true,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
Enabled: true,
},
},
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// all peers should match the policy
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should not match any peer
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 0, "expected rules count don't match")
}
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// should allow peer1 and peer2 to match the policy
newPolicy := &Policy{
ID: "policy2ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "policy2ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
}
account.Policies = append(account.Policies, newPolicy)
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 2, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// two posture checks should match only the peers that match both checks
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
account := getBasicAccountsWithResource()
account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}}
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router2Id",
NetworkID: network1ID,
AccountID: accID,
Peer: "router2Id",
})
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
}
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -1,47 +0,0 @@
package types
import (
"context"
"sync"
)
type Holder struct {
mu sync.RWMutex
accounts map[string]*Account
}
func NewHolder() *Holder {
return &Holder{
accounts: make(map[string]*Account),
}
}
func (h *Holder) GetAccount(id string) *Account {
h.mu.RLock()
defer h.mu.RUnlock()
return h.accounts[id]
}
func (h *Holder) AddAccount(account *Account) {
h.mu.Lock()
defer h.mu.Unlock()
a := h.accounts[account.Id]
if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() {
return
}
h.accounts[account.Id] = account
}
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
h.mu.Lock()
defer h.mu.Unlock()
if acc, ok := h.accounts[id]; ok {
return acc, nil
}
account, err := accGetter(ctx, id)
if err != nil {
return nil, err
}
h.accounts[id] = account
return account, nil
}

Some files were not shown because too many files have changed in this diff Show More