Compare commits

..

4 Commits

Author SHA1 Message Date
pascal
5a8dbef89b extract peer host method 2026-04-17 18:43:06 +02:00
pascal
569ebb400b fix error handling 2026-04-17 18:35:42 +02:00
pascal
8ec17daf3a fix testing tools 2026-04-17 18:23:06 +02:00
pascal
8bccbf9304 use peer fqdn on https 2026-04-17 18:10:05 +02:00
226 changed files with 4269 additions and 8607 deletions

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT): $(GOLANGCI_LINT):
@echo "Installing golangci-lint..." @echo "Installing golangci-lint..."
@mkdir -p ./bin @mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push) # Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT) lint: $(GOLANGCI_LINT)

View File

@@ -8,7 +8,6 @@ import (
"os" "os"
"slices" "slices"
"sync" "sync"
"time"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -16,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -28,7 +26,6 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -71,30 +68,7 @@ type Client struct {
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
} }
func (c *Client) RenewTun(fd int) error { func (c *Client) RenewTun(fd int) error {
cc := c.getConnectClient() if c.connectClient == nil {
if cc == nil {
return fmt.Errorf("engine not running") return fmt.Errorf("engine not running")
} }
e := cc.Engine() e := c.connectClient.Engine()
if e == nil { if e == nil {
return fmt.Errorf("engine not initialized") return fmt.Errorf("engine not initialized")
} }
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd) return e.RenewTun(fd)
} }
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level // SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() { func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
} }
func (c *Client) Networks() *NetworkArray { func (c *Client) Networks() *NetworkArray {
cc := c.getConnectClient() if c.connectClient == nil {
if cc == nil {
log.Error("not connected") log.Error("not connected")
return nil return nil
} }
engine := cc.Engine() engine := c.connectClient.Engine()
if engine == nil { if engine == nil {
log.Error("could not get engine") log.Error("could not get engine")
return nil return nil
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
} }
func (c *Client) getRouteManager() (routemanager.Manager, error) { func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.getConnectClient() client := c.connectClient
if client == nil { if client == nil {
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }

View File

@@ -7,5 +7,4 @@ package android
type PlatformFiles interface { type PlatformFiles interface {
ConfigurationFilePath() string ConfigurationFilePath() string
StateFilePath() string StateFilePath() string
CacheDir() string
} }

View File

@@ -13,8 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -31,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -98,7 +97,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store) peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl) settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager) jobManager := job.NewJobManager(nil, store, peersmanager)

View File

@@ -1,11 +0,0 @@
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
// versions, which returns EPERM to any other process that tries to insert
// rules into them. The workaround mirrors what Tailscale does: let firewalld
// itself add the accept rules to its own chains by trusting the interface.
package firewalld
// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"

View File

@@ -1,260 +0,0 @@
//go:build linux
package firewalld
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
)
const (
dbusDest = "org.fedoraproject.FirewallD1"
dbusPath = "/org/fedoraproject/FirewallD1"
dbusRootIface = "org.fedoraproject.FirewallD1"
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
errZoneAlreadySet = "ZONE_ALREADY_SET"
errAlreadyEnabled = "ALREADY_ENABLED"
errUnknownIface = "UNKNOWN_INTERFACE"
errNotEnabled = "NOT_ENABLED"
// callTimeout bounds each individual DBus or firewall-cmd invocation.
// A fresh context is created for each call so a slow DBus probe can't
// exhaust the deadline before the firewall-cmd fallback gets to run.
callTimeout = 3 * time.Second
)
var (
errDBusUnavailable = errors.New("firewalld dbus unavailable")
// trustLogOnce ensures the "added to trusted zone" message is logged at
// Info level only for the first successful add per process; repeat adds
// from other init paths are quieter.
trustLogOnce sync.Once
parentCtxMu sync.RWMutex
parentCtx context.Context = context.Background()
)
// SetParentContext installs a parent context whose cancellation aborts any
// in-flight TrustInterface call. It does not affect UntrustInterface, which
// always uses a fresh Background-rooted timeout so cleanup can still run
// during engine shutdown when the engine context is already cancelled.
func SetParentContext(ctx context.Context) {
parentCtxMu.Lock()
parentCtx = ctx
parentCtxMu.Unlock()
}
func getParentContext() context.Context {
parentCtxMu.RLock()
defer parentCtxMu.RUnlock()
return parentCtx
}
// TrustInterface places iface into firewalld's trusted zone if firewalld is
// running. It is idempotent and best-effort: errors are returned so callers
// can log, but a non-running firewalld is not an error. Only the first
// successful call per process logs at Info. Respects the parent context set
// via SetParentContext so startup-time cancellation unblocks it.
func TrustInterface(iface string) error {
parent := getParentContext()
if !isRunning(parent) {
return nil
}
if err := addTrusted(parent, iface); err != nil {
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
}
trustLogOnce.Do(func() {
log.Infof("added %s to firewalld trusted zone", iface)
})
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
return nil
}
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
// during shutdown after the engine context has been cancelled.
func UntrustInterface(iface string) error {
if !isRunning(context.Background()) {
return nil
}
if err := removeTrusted(context.Background(), iface); err != nil {
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
}
return nil
}
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, callTimeout)
}
func isRunning(parent context.Context) bool {
ctx, cancel := newCallContext(parent)
ok, err := isRunningDBus(ctx)
cancel()
if err == nil {
return ok
}
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
ctx, cancel = newCallContext(parent)
defer cancel()
return isRunningCLI(ctx)
}
return false
}
func addTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := addDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return addCLI(ctx, iface)
}
func removeTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := removeDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return removeCLI(ctx, iface)
}
func isRunningDBus(ctx context.Context) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
var zone string
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
}
return true, nil
}
func isRunningCLI(ctx context.Context) bool {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return false
}
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
}
func addDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errAlreadyEnabled) {
return nil
}
if dbusErrContains(call.Err, errZoneAlreadySet) {
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
if move.Err != nil {
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
}
return nil
}
return fmt.Errorf("firewalld addInterface: %w", call.Err)
}
func removeDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
return nil
}
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
}
func addCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
// --change-interface (no --permanent) binds the interface for the
// current runtime only; we do not want membership to persist across
// reboots because netbird re-asserts it on every startup.
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
).CombinedOutput()
if err != nil {
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
func removeCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
).CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
return nil
}
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
}
return nil
}
func dbusErrContains(err error, code string) bool {
if err == nil {
return false
}
var de dbus.Error
if errors.As(err, &de) {
for _, b := range de.Body {
if s, ok := b.(string); ok && strings.Contains(s, code) {
return true
}
}
}
return strings.Contains(err.Error(), code)
}

View File

@@ -1,49 +0,0 @@
//go:build linux
package firewalld
import (
"errors"
"testing"
"github.com/godbus/dbus/v5"
)
func TestDBusErrContains(t *testing.T) {
tests := []struct {
name string
err error
code string
want bool
}{
{"nil error", nil, errZoneAlreadySet, false},
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
{
"dbus.Error body match",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
errZoneAlreadySet,
true,
},
{
"dbus.Error body miss",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
errAlreadyEnabled,
false,
},
{
"dbus.Error non-string body falls back to Error()",
dbus.Error{Name: "x", Body: []any{123}},
"x",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := dbusErrContains(tc.err, tc.code)
if got != tc.want {
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
}
})
}
}

View File

@@ -1,25 +0,0 @@
//go:build !linux
package firewalld
import "context"
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
// intentionally empty: firewalld is a Linux-only daemon
}
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func TrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func UntrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -87,12 +86,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err) log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
} }
// Trust after all fatal init steps so a later failure doesn't leave the
// interface in firewalld's trusted zone without a corresponding Close.
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
go func() { go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
@@ -198,12 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
} }
// Appending to merr intentionally blocks DeleteState below so ShutdownState
// stays persisted and the crash-recovery path retries firewalld cleanup.
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
// attempt to delete state only if all other operations succeeded // attempt to delete state only if all other operations succeeded
if merr == nil { if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
@@ -230,11 +217,6 @@ func (m *Manager) AllowNetbird() error {
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("allow netbird interface traffic: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -218,10 +217,6 @@ func (m *Manager) AllowNetbird() error {
return fmt.Errorf("flush allow input netbird rules: %w", err) return fmt.Errorf("flush allow input netbird rules: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
@@ -41,8 +40,6 @@ const (
chainNameForward = "FORWARD" chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward" chainNameMangleForward = "netbird-mangle-forward"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept" userDataAcceptInputRule = "inputaccept"
@@ -136,10 +133,6 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
} }
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil { if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
} }
@@ -287,10 +280,6 @@ func (r *router) createContainers() error {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err) log.Errorf("failed to refresh rules: %s", err)
} }
@@ -1330,13 +1319,6 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false return false
} }
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip all iptables-managed tables in the ip family // Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false return false

View File

@@ -3,9 +3,6 @@
package uspfilter package uspfilter
import ( import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -19,9 +16,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager) return m.nativeFirewall.Close(stateManager)
} }
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil return nil
} }
@@ -30,8 +24,5 @@ func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird() return m.nativeFirewall.AllowNetbird()
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -9,7 +9,6 @@ import (
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() wgaddr.Address Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device

View File

@@ -31,20 +31,12 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
NameFunc func() string
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
} }
func (i *IFaceMock) Name() string {
if i.NameFunc == nil {
return "wgtest"
}
return i.NameFunc()
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device { func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil { if i.GetWGDeviceFunc == nil {
return nil return nil

View File

@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++ ipv6Count++
} }
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The assert.Equal(t, packetsPerFamily, ipv4Count)
// routing-correctness checks above are the real assertions; the counts assert.Equal(t, packetsPerFamily, ipv6Count)
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
} }
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {

View File

@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface // Close closes the tunnel interface
func (w *WGIface) Close() error { func (w *WGIface) Close() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error var result *multierror.Error
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
} }
// Release w.mu before calling w.tun.Close(): the underlying if err := w.tun.Close(); err != nil {
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
} }

View File

@@ -1,113 +0,0 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
} }
if u.address.Network.Contains(a) { if u.address.Network.Contains(a) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted) u.addrCache.Store(addr.String(), isRouted)
if isRouted { if isRouted {
// Extra log, as the error only shows up with ICE logging enabled // Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix) log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
} }
} }

View File

@@ -94,7 +94,6 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort, dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener, dnsReadyListener dns.ReadyListener,
stateFilePath string, stateFilePath string,
cacheDir string,
) error { ) error {
// in case of non Android os these variables will be nil // in case of non Android os these variables will be nil
mobileDependency := MobileDependency{ mobileDependency := MobileDependency{
@@ -104,7 +103,6 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath, StateFilePath: stateFilePath,
TempDir: cacheDir,
} }
return c.run(mobileDependency, nil, "") return c.run(mobileDependency, nil, "")
} }
@@ -340,7 +338,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager) c.statusRecorder.SetRelayMgr(relayManager)

View File

@@ -16,6 +16,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"slices"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -30,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
) )
const readmeContent = `Netbird debug bundle const readmeContent = `Netbird debug bundle
@@ -232,7 +234,6 @@ type BundleGenerator struct {
statusRecorder *peer.Status statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse syncResponse *mgmProto.SyncResponse
logPath string logPath string
tempDir string
cpuProfile []byte cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter clientMetrics MetricsExporter
@@ -255,7 +256,6 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse SyncResponse *mgmProto.SyncResponse
LogPath string LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter ClientMetrics MetricsExporter
@@ -275,7 +275,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder, statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse, syncResponse: deps.SyncResponse,
logPath: deps.LogPath, logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile, cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus, refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics, clientMetrics: deps.ClientMetrics,
@@ -288,7 +287,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location. // Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) { func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip") bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil { if err != nil {
return "", fmt.Errorf("create zip file: %w", err) return "", fmt.Errorf("create zip file: %w", err)
} }
@@ -374,8 +373,15 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err) log.Errorf("failed to add wg show output: %v", err)
} }
if err := g.addPlatformLog(); err != nil { if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
log.Errorf("failed to add logs to debug bundle: %v", err) if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
} }
if err := g.addUpdateLogs(); err != nil { if err := g.addUpdateLogs(); err != nil {

View File

@@ -1,41 +0,0 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -1,25 +0,0 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -3,12 +3,10 @@ package debug
import ( import (
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci") t.Skip("Skipping upload test on docker ci")
} }
testDir := t.TempDir() testDir := t.TempDir()
addr := reserveLoopbackPort(t) testURL := "http://localhost:8080"
testURL := "http://" + addr
t.Setenv("SERVER_URL", testURL) t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir) t.Setenv("STORE_DIR", testDir)
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err) t.Errorf("Failed to stop server: %v", err)
} }
}) })
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile") file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content") fileContent := []byte("test file content")
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent) require.Equal(t, fileContent, createdFileContent)
} }
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}

View File

@@ -13,7 +13,6 @@ import (
const ( const (
defaultResolvConfPath = "/etc/resolv.conf" defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
) )
type resolvConf struct { type resolvConf struct {

View File

@@ -1,10 +1,7 @@
package dns package dns
import ( import (
"context"
"fmt" "fmt"
"math"
"net"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
} }
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) { if !c.isHandlerMatch(qname, entry) {
continue continue
} }
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime)) cw.response.Len(), meta, time.Since(startTime))
} }
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch { switch {
case entry.Pattern == ".": case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
} }
} }
} }
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,15 +1,11 @@
package dns_test package dns_test
import ( import (
"context"
"net"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
}) })
} }
} }
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, reason, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err) return nil, fmt.Errorf("get os dns manager type: %w", err)
} }
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason) log.Infof("System DNS manager discovered: %s", osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager) mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil { if err != nil {
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
} }
} }
func getOSDNSManagerType() (osManagerType, string, error) { func getOSDNSManagerType() (osManagerType, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
file, err := os.Open(defaultResolvConfPath) file, err := os.Open(defaultResolvConfPath)
if err != nil { if err != nil {
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
} }
defer func() { defer func() {
if cerr := file.Close(); cerr != nil { if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr) log.Errorf("close file %s: %s", defaultResolvConfPath, err)
} }
}() }()
var rejected []string
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
continue continue
} }
if text[0] != '#' { if text[0] != '#' {
break return fileManager, nil
} }
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" { if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return mgr, reason, nil, nil return netbirdManager, nil
} else if rej != "" { }
rejected = append(rejected, rej) if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
} }
} }
if err := scanner.Err(); err != nil && err != io.EOF { if err := scanner.Err(); err != nil && err != io.EOF {
return 0, "", nil, fmt.Errorf("scan: %w", err) return 0, fmt.Errorf("scan: %w", err)
} }
return 0, "", rejected, nil
return fileManager, nil
} }
// matchResolvConfHeader inspects a single comment line. Returns either a // checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
func checkStub() bool { func checkStub() bool {
rConf, err := parseDefaultResolvConf() rConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err) log.Warnf("failed to parse resolv conf: %s", err)
return true return true
} }
@@ -178,36 +139,3 @@ func checkStub() bool {
return false return false
} }
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -1,76 +0,0 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,83 +2,40 @@ package mgmt
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"net/url" "net/url"
"os"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
const ( const dnsTimeout = 5 * time.Second
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing. // Resolver caches critical NetBird infrastructure domains
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct { type Resolver struct {
records map[dns.Question]*cachedRecord records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex mutex sync.RWMutex
}
chain ChainResolver type ipsResponse struct {
chainMaxPriority int ips []netip.Addr
refreshGroup singleflight.Group err error
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
} }
// NewResolver creates a new management domains cache resolver. // NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
records: make(map[dns.Question]*cachedRecord), records: make(map[dns.Question][]dns.RR),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
} }
} }
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver" return "MgmtCacheResolver"
} }
// SetChainResolver wires the handler chain used to refresh stale cache entries. // ServeDNS implements dns.Handler interface.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
m.continueToNext(w, r) m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
m.mutex.RLock() m.mutex.RLock()
cached, found := m.records[question] records, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
m.mutex.RUnlock() m.mutex.RUnlock()
if !found { if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{} resp := &dns.Msg{}
resp.SetReply(r) resp.SetReply(r)
resp.Authoritative = false resp.Authoritative = false
resp.RecursionAvailable = true resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// AddDomain resolves a domain and stores its A/AAAA records in the cache. // AddDomain manually adds a domain to cache by resolving it.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout) ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel() defer cancel()
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
if errA != nil && errAAAA != nil { return err
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
} }
if len(aRecords) == 0 && len(aaaaRecords) == 0 { var aRecords, aaaaRecords []dns.RR
if err := errors.Join(errA, errAAAA); err != nil { for _, ip := range ips {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
} }
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
} }
now := time.Now()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) if len(aRecords) > 0 {
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords)) d.SafeString(), len(aRecords), len(aaaaRecords))
return nil return nil
} }
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
// untouched on error. Caller holds m.mutex. log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} resultChan := make(chan *ipsResponse, 1)
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per go func() {
// unique in-flight key; bursty stale hits share its channel. expected is the ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
// cachedRecord pointer observed by the caller; the refresh only mutates the resultChan <- &ipsResponse{
// cache if that pointer is still the one stored, so a stale in-flight refresh err: err,
// can't clobber a newer entry written by AddDomain or a competing refresh. ips: ips,
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
} }
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", }()
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
} }
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale. if resp.err != nil {
if len(records) == 0 { return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
} }
return resp.ips, nil
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
} }
// PopulateFromConfig extracts and caches domains from the client configuration. // PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} aQuestion := dns.Question{
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} Name: dnsName,
delete(m.records, qA) Qtype: dns.TypeA,
delete(m.records, qAAAA) Qclass: dns.ClassINET,
delete(m.refreshing, qA) }
delete(m.refreshing, qAAAA) delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString()) log.Debugf("removed domain=%s from cache", d.SafeString())
return nil return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains return domains
} }
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -1,408 +0,0 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,7 +6,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains()) assert.False(t, resolver.MatchSubdomains())
} }
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) { func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -212,7 +212,6 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver() mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,

View File

@@ -26,7 +26,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
@@ -141,7 +140,6 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config ProfileConfig *profilemanager.Config
LogPath string LogPath string
TempDir string
} }
// EngineServices holds the external service dependencies required by the Engine. // EngineServices holds the external service dependencies required by the Engine.
@@ -571,7 +569,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx) e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed()) e.srWatcher.Start()
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
@@ -605,8 +603,6 @@ func (e *Engine) createFirewall() error {
return nil return nil
} }
firewalld.SetParentContext(e.ctx)
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil { if err != nil {
@@ -1099,7 +1095,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder, StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse, SyncResponse: syncResponse,
LogPath: e.config.LogPath, LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics, ClientMetrics: e.clientMetrics,
RefreshStatus: func() { RefreshStatus: func() {
e.RunHealthProbes(true) e.RunHealthProbes(true)

View File

@@ -25,7 +25,6 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
@@ -58,6 +57,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -1632,7 +1632,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store) peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager) jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)

View File

@@ -22,8 +22,4 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager DnsManager dns.IosDnsManager
FileDescriptor int32 FileDescriptor int32
StateFilePath string StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
} }

View File

@@ -185,20 +185,17 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
forceRelay := IsForceRelayed() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
if !forceRelay { workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() if err != nil {
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) return err
if err != nil {
return err
}
conn.workerICE = workerICE
} }
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !forceRelay { if !isForceRelayed() {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
} }
@@ -254,9 +251,7 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgWatcherCancel() conn.wgWatcherCancel()
} }
conn.workerRelay.CloseConn() conn.workerRelay.CloseConn()
if conn.workerICE != nil { conn.workerICE.Close()
conn.workerICE.Close()
}
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn() err := conn.wgProxyRelay.CloseConn()
@@ -299,9 +294,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate() conn.dumpState.RemoteCandidate()
if conn.workerICE != nil { conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
} }
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
@@ -719,35 +712,33 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusConnecting return StatusConnecting
} }
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports. func (conn *Conn) isConnectedOnAllWay() (connected bool) {
// // would be better to protect this with a mutex, but it could cause deadlock with Close function
// The result is a tri-state:
// - ConnStatusConnected: all available transports are up
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
// - ConnStatusDisconnected: no working transport
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
defer func() { defer func() {
if status == guard.ConnStatusDisconnected { if !connected {
conn.logTraceConnState() conn.logTraceConnState()
} }
}() }()
iceWorkerCreated := conn.workerICE != nil // For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
var iceInProgress bool return conn.statusRelay.Get() == worker.StatusConnected
if iceWorkerCreated {
iceInProgress = conn.workerICE.InProgress()
} }
return evalConnStatus(connStatusInputs{ // For non-JS platforms: check ICE connection status
forceRelay: IsForceRelayed(), if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(), return false
relayConnected: conn.statusRelay.Get() == worker.StatusConnected, }
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
iceWorkerCreated: iceWorkerCreated, // If relay is supported with peer, it must also be connected
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected, if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
iceInProgress: iceInProgress, if conn.statusRelay.Get() == worker.StatusDisconnected {
}) return false
}
}
return true
} }
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
@@ -935,43 +926,3 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil return remoteRosenpassPubKey != nil
} }
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
// "Relay up and needed" — the peer uses relay and the transport is connected.
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
if in.forceRelay {
return boolToConnStatus(relayUsedAndUp)
}
// Remote peer doesn't support ICE, or we haven't created the worker yet:
// relay is the only possible transport.
if !in.remoteSupportsICE || !in.iceWorkerCreated {
return boolToConnStatus(relayUsedAndUp)
}
// ICE counts as "up" when the status is anything other than Disconnected, OR
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
iceUp := in.iceStatusConnecting || in.iceInProgress
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
relayOK := !in.peerUsesRelay || in.relayConnected
switch {
case iceUp && relayOK:
return guard.ConnStatusConnected
case relayUsedAndUp:
// Relay is up but ICE is down — partially connected.
return guard.ConnStatusPartiallyConnected
default:
return guard.ConnStatusDisconnected
}
}
func boolToConnStatus(connected bool) guard.ConnStatus {
if connected {
return guard.ConnStatusConnected
}
return guard.ConnStatusDisconnected
}

View File

@@ -13,20 +13,6 @@ const (
StatusConnected StatusConnected
) )
// connStatusInputs is the primitive-valued snapshot of the state that drives the
// tri-state connection classification. Extracted so the decision logic can be unit-tested
// without constructing full Worker/Handshaker objects.
type connStatusInputs struct {
forceRelay bool // NB_FORCE_RELAY or JS/WASM
peerUsesRelay bool // remote peer advertises relay support AND local has relay
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
remoteSupportsICE bool // remote peer sent ICE credentials
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
iceStatusConnecting bool // statusICE is anything other than Disconnected
iceInProgress bool // a negotiation is currently in flight
}
// ConnStatus describe the status of a peer's connection // ConnStatus describe the status of a peer's connection
type ConnStatus int32 type ConnStatus int32

View File

@@ -1,201 +0,0 @@
package peer
import (
"testing"
"github.com/netbirdio/netbird/client/internal/peer/guard"
)
func TestEvalConnStatus_ForceRelay(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "force relay, peer uses relay, relay up",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: true,
},
want: guard.ConnStatusConnected,
},
{
name: "force relay, peer uses relay, relay down",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: false,
},
want: guard.ConnStatusDisconnected,
},
{
name: "force relay, peer does NOT use relay - disconnected forever",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: false,
relayConnected: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "remote does not support ICE, peer uses relay, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer uses relay, relay down",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE worker not yet created, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: true,
iceWorkerCreated: false,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer does not use relay",
in: connStatusInputs{
peerUsesRelay: false,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
base := connStatusInputs{
remoteSupportsICE: true,
iceWorkerCreated: true,
}
tests := []struct {
name string
mutator func(*connStatusInputs)
want guard.ConnStatus
}{
{
name: "ICE connected, relay connected, peer uses relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE connected, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE InProgress only, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.iceStatusConnecting = false
in.iceInProgress = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE down, relay up, peer uses relay -> partial",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusPartiallyConnected,
},
{
name: "ICE down, peer does NOT use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = false
in.iceStatusConnecting = true
},
// relayOK = false (peer uses relay but it's down), iceUp = true
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
// falls into default: Disconnected.
want: guard.ConnStatusDisconnected,
},
{
name: "ICE down, relay up but peer does not use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = true // not actually used since peer doesn't rely on it
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
in := base
tc.mutator(&in)
if got := evalConnStatus(in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
}
})
}
}

View File

@@ -10,7 +10,7 @@ const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY" EnvKeyNBForceRelay = "NB_FORCE_RELAY"
) )
func IsForceRelayed() bool { func isForceRelayed() bool {
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
return true return true
} }

View File

@@ -8,19 +8,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// ConnStatus represents the connection state as seen by the guard. type isConnectedFunc func() bool
type ConnStatus int
const (
// ConnStatusDisconnected means neither ICE nor Relay is connected.
ConnStatusDisconnected ConnStatus = iota
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
ConnStatusPartiallyConnected
// ConnStatusConnected means all required connections are established.
ConnStatusConnected
)
type connStatusFunc func() ConnStatus
// Guard is responsible for the reconnection logic. // Guard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues. // It will trigger to send an offer to the peer then has connection issues.
@@ -32,14 +20,14 @@ type connStatusFunc func() ConnStatus
// - ICE candidate changes // - ICE candidate changes
type Guard struct { type Guard struct {
log *log.Entry log *log.Entry
isConnectedOnAllWay connStatusFunc isConnectedOnAllWay isConnectedFunc
timeout time.Duration timeout time.Duration
srWatcher *SRWatcher srWatcher *SRWatcher
relayedConnDisconnected chan struct{} relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{} iCEConnDisconnected chan struct{}
} }
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
log: log, log: log,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
@@ -69,17 +57,8 @@ func (g *Guard) SetICEConnDisconnected() {
} }
} }
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity. // reconnectLoopWithRetry periodically check the connection status.
// // Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
// Behavior depends on the connection state reported by isConnectedOnAllWay:
// - Connected: no action, the peer is fully reachable.
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
//
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan) defer g.srWatcher.RemoveListener(srReconnectedChan)
@@ -89,47 +68,36 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
tickerChannel := ticker.C tickerChannel := ticker.C
iceState := &iceRetryState{log: g.log}
defer iceState.reset()
for { for {
select { select {
case <-tickerChannel: case t := <-tickerChannel:
switch g.isConnectedOnAllWay() { if t.IsZero() {
case ConnStatusConnected: g.log.Infof("retry timed out, stop periodic offer sending")
// all good, nothing to do // after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
case ConnStatusDisconnected: tickerChannel = make(<-chan time.Time)
callback() continue
case ConnStatusPartiallyConnected:
if iceState.shouldRetry() {
callback()
} else {
iceState.enterHourlyMode()
ticker.Stop()
tickerChannel = iceState.hourlyC()
}
} }
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker") g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.newReconnectTicker(ctx) ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-g.iCEConnDisconnected: case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker") g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.newReconnectTicker(ctx) ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-srReconnectedChan: case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker") g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.newReconnectTicker(ctx) ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-ctx.Done(): case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop") g.log.Debugf("context is done, stop reconnect loop")
@@ -152,7 +120,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
return backoff.NewTicker(bo) return backoff.NewTicker(bo)
} }
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1, RandomizationFactor: 0.1,

View File

@@ -1,61 +0,0 @@
package guard
import (
"time"
log "github.com/sirupsen/logrus"
)
const (
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
maxICERetries = 3
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
iceRetryInterval = 1 * time.Hour
)
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
// After maxICERetries attempts it switches to a periodic hourly retry.
type iceRetryState struct {
log *log.Entry
retries int
hourly *time.Ticker
}
func (s *iceRetryState) reset() {
s.retries = 0
if s.hourly != nil {
s.hourly.Stop()
s.hourly = nil
}
}
// shouldRetry reports whether the caller should send another ICE offer on this tick.
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
// to the hourly ticker via enterHourlyMode + hourlyC.
func (s *iceRetryState) shouldRetry() bool {
if s.hourly != nil {
s.log.Debugf("hourly ICE retry attempt")
return true
}
s.retries++
if s.retries <= maxICERetries {
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
return true
}
return false
}
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
func (s *iceRetryState) enterHourlyMode() {
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
s.hourly = time.NewTicker(iceRetryInterval)
}
func (s *iceRetryState) hourlyC() <-chan time.Time {
if s.hourly == nil {
return nil
}
return s.hourly.C
}

View File

@@ -1,103 +0,0 @@
package guard
import (
"testing"
log "github.com/sirupsen/logrus"
)
func newTestRetryState() *iceRetryState {
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
}
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
s := newTestRetryState()
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
}
}
}
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries; i++ {
_ = s.shouldRetry()
}
if s.shouldRetry() {
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
}
}
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
s := newTestRetryState()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
}
}
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
defer s.reset()
if s.hourlyC() == nil {
t.Fatalf("hourlyC returned nil after enterHourlyMode")
}
}
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
s := newTestRetryState()
s.enterHourlyMode()
defer s.reset()
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false in hourly mode, want true")
}
// Subsequent calls also return true — we keep retrying on each hourly tick.
if !s.shouldRetry() {
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
}
}
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
s.reset()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel after reset")
}
if s.retries != 0 {
t.Fatalf("retries = %d after reset, want 0", s.retries)
}
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
}
}
}
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
s := newTestRetryState()
s.reset()
s.reset() // second call must not panic or re-stop a nil ticker
if s.hourlyC() != nil {
t.Fatalf("hourlyC non-nil after double reset")
}
}

View File

@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
return srw return srw
} }
func (w *SRWatcher) Start(disableICEMonitor bool) { func (w *SRWatcher) Start() {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -50,10 +50,8 @@ func (w *SRWatcher) Start(disableICEMonitor bool) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel w.cancelIceMonitor = cancel
if !disableICEMonitor { iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) go iceMonitor.Start(ctx, w.onICEChanged)
go iceMonitor.Start(ctx, w.onICEChanged)
}
w.signalClient.SetOnReconnectedListener(w.onReconnected) w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"sync" "sync"
"sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -44,10 +43,6 @@ type OfferAnswer struct {
SessionID *ICESessionID SessionID *ICESessionID
} }
func (o *OfferAnswer) hasICECredentials() bool {
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
}
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
@@ -64,10 +59,6 @@ type Handshaker struct {
relayListener *AsyncOfferListener relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer) iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
remoteICESupported atomic.Bool
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
@@ -75,7 +66,7 @@ type Handshaker struct {
} }
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
h := &Handshaker{ return &Handshaker{
log: log, log: log,
config: config, config: config,
signaler: signaler, signaler: signaler,
@@ -85,13 +76,6 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
remoteOffersCh: make(chan OfferAnswer), remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer),
} }
// assume remote supports ICE until we learn otherwise from received offers
h.remoteICESupported.Store(ice != nil)
return h
}
func (h *Handshaker) RemoteICESupported() bool {
return h.remoteICESupported.Load()
} }
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
@@ -106,20 +90,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
for { for {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
// Record signaling received for reconnection attempts // Record signaling received for reconnection attempts
if h.metricsStages != nil { if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived() h.metricsStages.RecordSignalingReceived()
} }
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil { if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
} }
if h.iceListener != nil && h.RemoteICESupported() { if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer) h.iceListener(&remoteOfferAnswer)
} }
@@ -128,20 +110,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue continue
} }
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
// Record signaling received for reconnection attempts // Record signaling received for reconnection attempts
if h.metricsStages != nil { if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived() h.metricsStages.RecordSignalingReceived()
} }
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil { if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
} }
if h.iceListener != nil && h.RemoteICESupported() { if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer) h.iceListener(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
@@ -203,18 +183,15 @@ func (h *Handshaker) sendAnswer() error {
} }
func (h *Handshaker) buildOfferAnswer() OfferAnswer { func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{ answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort, WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(), Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr, RosenpassAddr: h.config.RosenpassConfig.Addr,
} SessionID: &sid,
if h.ice != nil && h.RemoteICESupported() {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer.IceCredentials = IceCredentials{uFrag, pwd}
answer.SessionID = &sid
} }
if addr, err := h.relay.RelayInstanceAddress(); err == nil { if addr, err := h.relay.RelayInstanceAddress(); err == nil {
@@ -223,18 +200,3 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
return answer return answer
} }
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
hasICE := offer.hasICECredentials()
prev := h.remoteICESupported.Swap(hasICE)
if prev != hasICE {
if hasICE {
h.log.Infof("remote peer started sending ICE credentials")
} else {
h.log.Infof("remote peer stopped sending ICE credentials")
if h.ice != nil {
h.ice.Close()
}
}
}
}

View File

@@ -46,13 +46,9 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer // SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
var sessionIDBytes []byte sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if offerAnswer.SessionID != nil { if err != nil {
var err error log.Warnf("failed to get session ID bytes: %v", err)
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
} }
msg, err := signal.MarshalCredential( msg, err := signal.MarshalCredential(
s.wgPrivateKey, s.wgPrivateKey,

View File

@@ -1,10 +0,0 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }

View File

@@ -1,241 +0,0 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
if err := r.addScopedDefault(unspec, nexthop); err != nil {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks" "github.com/netbirdio/netbird/client/net/hooks"
) )
@@ -32,6 +31,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
@@ -396,16 +397,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
} }
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
if nbnet.AdvancedRouting() { localRoutes, err := hasSeparateRouting()
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
if err != nil { if err != nil {
log.Errorf("Failed to get routes: %v", err) if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{} return false, netip.Prefix{}
} }

View File

@@ -22,6 +22,10 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil return []netip.Prefix{}, nil
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM. // GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil return []DetailedRoute{}, nil

View File

@@ -894,6 +894,13 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool { func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,6 +48,10 @@ func EnableIPForwarding() error {
return nil return nil
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms) // GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) { func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS) log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,9 +25,6 @@ import (
const ( const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
) )
var routeProtoFlag int var routeProtoFlag int
@@ -44,42 +41,26 @@ func init() {
} }
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. // FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error { func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB() rib, err := retryFetchRIB()
if err != nil { if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err))) return fmt.Errorf("fetch routing table: %w", err)
} }
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil { if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err))) return fmt.Errorf("parse routing table: %w", err)
} }
var merr *multierror.Error
flushedCount := 0 flushedCount := 0
for _, msg := range msgs { for _, msg := range msgs {
@@ -136,12 +117,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix) return fmt.Errorf("invalid prefix: %s", prefix)
} }
msg, err := r.buildRouteMessage(action, prefix, nexthop) expBackOff := backoff.NewExponentialBackOff()
if err != nil { expBackOff.InitialInterval = 50 * time.Millisecond
return fmt.Errorf("build route message: %w", err) expBackOff.MaxInterval = 500 * time.Millisecond
} expBackOff.MaxElapsedTime = 1 * time.Second
if err := r.writeRouteMessage(msg, routeBudget); err != nil { if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
a := "add" a := "add"
if action == unix.RTM_DELETE { if action == unix.RTM_DELETE {
a = "remove" a = "remove"
@@ -151,91 +132,50 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil return nil
} }
// writeRouteMessage sends a route message over AF_ROUTE and waits for the func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
// kernel's matching reply, retrying transient failures until budget elapses. operation := func() error {
// Callers do not need to manage sockets or seq numbers themselves. fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
}
}()
tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
if err != nil { if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) { return fmt.Errorf("open routing socket: %w", err)
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline. }
continue defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
} }
return backoff.Permanent(fmt.Errorf("read: %w", err)) }()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
} }
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
continue msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
} }
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid) if _, err = unix.Write(fd, msgBytes); err != nil {
// uniquely identifies our own reply among broadcast traffic. if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid { return fmt.Errorf("write: %w", err)
continue }
return backoff.Permanent(fmt.Errorf("write: %w", err))
} }
if hdr.Errno != 0 {
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno))) respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
} }
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil return nil
} }
return operation
} }
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -243,7 +183,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action, Type: action,
Flags: unix.RTF_UP | routeProtoFlag, Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(), Seq: r.getSeq(),
} }
@@ -282,6 +221,19 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil return msg, nil
} }
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr). // addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) { func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() { if addr.Is4() {

View File

@@ -1,5 +0,0 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin //go:build !linux && !windows
package net package net

24
client/net/env_android.go Normal file
View File

@@ -0,0 +1,24 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin //go:build !linux && !windows && !android
package net package net

View File

@@ -1,25 +0,0 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || windows //go:build windows
package net package net
@@ -24,22 +24,17 @@ func Init() {
} }
func checkAdvancedRoutingSupport() bool { func checkAdvancedRoutingSupport() bool {
legacyRouting := false var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" { if val := os.Getenv(envUseLegacyRouting); val != "" {
parsed, err := strconv.ParseBool(val) legacyRouting, err = strconv.ParseBool(val)
if err != nil { if err != nil {
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err) log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
} else {
legacyRouting = parsed
} }
} }
if legacyRouting { if legacyRouting || netstack.IsEnabled() {
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting) log.Info("advanced routing has been requested to be disabled")
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
return false return false
} }

View File

@@ -1,5 +0,0 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin //go:build !linux && !windows
package net package net

View File

@@ -1,160 +0,0 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -15,8 +15,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -40,6 +38,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -306,7 +305,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store) peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl) settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager) jobManager := job.NewJobManager(nil, store, peersManager)

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@@ -137,8 +138,10 @@ func restoreResidualState(ctx context.Context, statePath string) error {
} }
// clean up any remaining routes independently of the state file // clean up any remaining routes independently of the state file
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { if !nbnet.AdvancedRouting() {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)

View File

@@ -187,23 +187,24 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
return "", fmt.Errorf("get NetBird executable path: %w", err) return "", fmt.Errorf("get NetBird executable path: %w", err)
} }
hostList := strings.Join(deduplicatedPatterns, ",") hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath) config := fmt.Sprintf("Host %s\n", hostLine)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n" config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
config += " PasswordAuthentication yes\n" config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PubkeyAuthentication yes\n" config += " PasswordAuthentication yes\n"
config += " BatchMode no\n" config += " PubkeyAuthentication yes\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) config += " BatchMode no\n"
config += " StrictHostKeyChecking no\n" config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n" config += " UserKnownHostsFile NUL\n"
} else { } else {
config += " UserKnownHostsFile /dev/null\n" config += " UserKnownHostsFile /dev/null\n"
} }
config += " CheckHostIP no\n" config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n" config += " LogLevel ERROR\n\n"
return config, nil return config, nil
} }

View File

@@ -116,37 +116,6 @@ func TestManager_PeerLimit(t *testing.T) {
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
} }
func TestManager_MatchHostFormat(t *testing.T) {
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
peers := []PeerSSHInfo{
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
"should use Match host with comma-separated patterns")
}
func TestManager_ForcedSSHConfig(t *testing.T) { func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable // Set force environment variable
t.Setenv(EnvForceSSHConfig, "true") t.Setenv(EnvForceSSHConfig, "true")

View File

@@ -2,6 +2,7 @@ package system
import ( import (
"context" "context"
"net"
"net/netip" "net/netip"
"strings" "strings"
@@ -144,6 +145,59 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v return v
} }
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks. // GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks)) log.Debugf("gathering system information with checks: %d", len(checks))

View File

@@ -2,8 +2,6 @@ package system
import ( import (
"context" "context"
"net"
"net/netip"
"runtime" "runtime"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -44,66 +42,6 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
// networkAddresses returns the list of network addresses on iOS.
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
// check that other platforms use and filter out link-local addresses as they
// are not useful for posture checks.
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: ""}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path. // checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) { func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil return []File{}, nil

View File

@@ -1,66 +0,0 @@
//go:build !ios
package system
import (
"net"
"net/netip"
)
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
mac := iface.HardwareAddr.String()
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address, mac)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: mac}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}

View File

@@ -119,8 +119,6 @@ server:
# Reverse proxy settings (optional) # Reverse proxy settings (optional)
# reverseProxy: # reverseProxy:
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"]) # trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies) # trustedHTTPProxiesCount: 0
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"]) # trustedPeers: []
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.

View File

@@ -457,18 +457,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err) require.NoError(t, err)
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
// which is when Receive has actually returned. Without this, the Receive goroutine
// can outlive the test and call t.Logf after teardown, panicking.
receiveDone := make(chan struct{})
t.Cleanup(func() {
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Error("Receive goroutine did not exit after Close")
}
})
t.Cleanup(func() { t.Cleanup(func() {
err := client.Close() err := client.Close()
assert.NoError(t, err, "failed to close flow") assert.NoError(t, err, "failed to close flow")
@@ -480,7 +468,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
receivedAfterReconnect := make(chan struct{}) receivedAfterReconnect := make(chan struct{})
go func() { go func() {
defer close(receiveDone)
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 { if msg.IsInitiator || len(msg.EventId) == 0 {
return nil return nil

4
go.mod
View File

@@ -71,7 +71,7 @@ require (
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2 github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
@@ -323,5 +323,3 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

8
go.sum
View File

@@ -400,6 +400,8 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
@@ -447,14 +449,12 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 h1:g74mB64wnjCagzE1spKgPfTI/ont1SdSL3uX5bOecgM= github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34/go.mod h1:lCOq5d1i19AQjEEW2d7aNK0Nn0KC0MKyfMz/PLwVBFg= github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
echo "Registering CrowdSec bouncer..." echo "Registering CrowdSec bouncer..."
local cs_retries=0 local cs_retries=0
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
cs_retries=$((cs_retries + 1)) cs_retries=$((cs_retries + 1))
if [[ $cs_retries -ge 30 ]]; then if [[ $cs_retries -ge 30 ]]; then
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr

View File

@@ -16,6 +16,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -35,15 +38,17 @@ type Manager interface {
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
accountManager account.Manager accountManager account.Manager
networkMapController network_map.Controller networkMapController network_map.Controller
} }
func NewManager(store store.Store) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager,
} }
} }
@@ -60,10 +65,28 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
} }
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
} }
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
} }

View File

@@ -5,11 +5,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/shared/auth" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
) )
@@ -18,15 +15,21 @@ type handler struct {
manager accesslogs.Manager manager accesslogs.Manager
} }
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) { func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS") router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
} }
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var filter accesslogs.AccessLogFilter var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r) filter.ParseFromRequest(r)

View File

@@ -9,19 +9,25 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
geo geolocation.Geolocation permissionsManager permissions.Manager
cleanupCancel context.CancelFunc geo geolocation.Geolocation
cleanupCancel context.CancelFunc
} }
func NewManager(store store.Store, geo geolocation.Geolocation) accesslogs.Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
geo: geo, permissionsManager: permissionsManager,
geo: geo,
} }
} }
@@ -57,6 +63,14 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering // GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) { func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
}
if !ok {
return nil, 0, status.NewPermissionDeniedError()
}
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil { if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err) log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
} }

View File

@@ -6,11 +6,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/shared/auth" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -20,15 +17,15 @@ type handler struct {
manager Manager manager Manager
} }
func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) { func RegisterEndpoints(router *mux.Router, manager Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS") router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS") router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS") router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Create, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") // TODO: this should be a POST router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
} }
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
@@ -59,7 +56,13 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
return resp return resp
} }
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId) domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -74,7 +77,13 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, ret) util.WriteJSONObject(r.Context(), w, ret)
} }
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiReverseProxiesDomainsJSONRequestBody var req api.PostApiReverseProxiesDomainsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -90,7 +99,13 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, use
util.WriteJSONObject(r.Context(), w, domainToApi(domain)) util.WriteJSONObject(r.Context(), w, domainToApi(domain))
} }
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"] domainID := mux.Vars(r)["domainId"]
if domainID == "" { if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
@@ -105,7 +120,13 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, use
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"] domainID := mux.Vars(r)["domainId"]
if domainID == "" { if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)

View File

@@ -11,7 +11,11 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
type store interface { type store interface {
@@ -33,22 +37,32 @@ type proxyManager interface {
} }
type Manager struct { type Manager struct {
store store store store
validator domain.Validator validator domain.Validator
proxyManager proxyManager proxyManager proxyManager
accountManager account.Manager permissionsManager permissions.Manager
accountManager account.Manager
} }
func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager { func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
return Manager{ return Manager{
store: store, store: store,
proxyManager: proxyMgr, proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver}, validator: domain.Validator{Resolver: net.DefaultResolver},
accountManager: accountManager, permissionsManager: permissionsManager,
accountManager: accountManager,
} }
} }
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
domains, err := m.store.ListCustomDomains(ctx, accountID) domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err) return nil, fmt.Errorf("list custom domains: %w", err)
@@ -104,6 +118,14 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
} }
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) { func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters // Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil { if err != nil {
@@ -137,6 +159,14 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
} }
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error { func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
d, err := m.store.GetCustomDomain(ctx, accountID, domainID) d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil { if err != nil {
return fmt.Errorf("get domain from store: %w", err) return fmt.Errorf("get domain from store: %w", err)
@@ -153,6 +183,21 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
} }
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) { func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
return
}
if !ok {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
}
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"accountID": accountID, "accountID": accountID,
"domainID": domainID, "domainID": domainID,

View File

@@ -23,7 +23,7 @@ type Proxy struct {
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time ConnectedAt *time.Time
DisconnectedAt *time.Time DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"` Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time

View File

@@ -6,14 +6,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -32,19 +30,25 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
} }
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager) domainmanager.RegisterEndpoints(domainRouter, domainManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager) accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
} }
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId) allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -59,7 +63,13 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAut
util.WriteJSONObject(r.Context(), w, apiServices) util.WriteJSONObject(r.Context(), w, apiServices)
} }
func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.ServiceRequest var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -67,13 +77,12 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth
} }
service := new(rpservice.Service) service := new(rpservice.Service)
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
if err := service.Validate(); err != nil { if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -87,7 +96,13 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
} }
func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"] serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" { if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -103,7 +118,13 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
} }
func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"] serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" { if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -118,13 +139,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth
service := new(rpservice.Service) service := new(rpservice.Service)
service.ID = serviceID service.ID = serviceID
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil { if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
if err := service.Validate(); err != nil { if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -138,7 +158,13 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
} }
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"] serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" { if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -153,7 +179,13 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId) clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@@ -85,17 +86,18 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountMgr := &mock_server.MockAccountManager{ accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}, },
} }
mgr := &Manager{ mgr := &Manager{
store: testStore, store: testStore,
accountManager: accountMgr, accountManager: accountMgr,
proxyController: mockCtrl, permissionsManager: permissions.NewManager(testStore),
capabilities: mockCaps, proxyController: mockCtrl,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, capabilities: mockCaps,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
} }
mgr.exposeReaper = &exposeReaper{manager: mgr} mgr.exposeReaper = &exposeReaper{manager: mgr}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -21,6 +22,9 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -79,22 +83,26 @@ type CapabilityProvider interface {
} }
type Manager struct { type Manager struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
proxyController proxy.Controller permissionsManager permissions.Manager
capabilities CapabilityProvider proxyController proxy.Controller
clusterDeriver ClusterDeriver networkMapController network_map.Controller
exposeReaper *exposeReaper capabilities CapabilityProvider
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
} }
// NewManager creates a new service manager. // NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver, networkMapController network_map.Controller) *Manager {
mgr := &Manager{ mgr := &Manager{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
proxyController: proxyController, permissionsManager: permissionsManager,
capabilities: capabilities, proxyController: proxyController,
clusterDeriver: clusterDeriver, networkMapController: networkMapController,
capabilities: capabilities,
clusterDeriver: clusterDeriver,
} }
mgr.exposeReaper = &exposeReaper{manager: mgr} mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr return mgr
@@ -107,10 +115,26 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
// GetActiveClusters returns all active proxy clusters with their connected proxy count. // GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx) return m.store.GetActiveProxyClusters(ctx)
} }
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err) return nil, fmt.Errorf("failed to get services: %w", err)
@@ -130,13 +154,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
for _, target := range s.Targets { for _, target := range s.Targets {
switch target.TargetType { switch target.TargetType {
case service.TargetTypePeer: case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) target.Host = m.getPeerTargetHost(ctx, accountID, target)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case service.TargetTypeHost: case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil { if err != nil {
@@ -163,7 +181,35 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
return nil return nil
} }
func (m *Manager) getPeerTargetHost(ctx context.Context, accountID string, target *service.Target) string {
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, target.ServiceID, err)
return unknownHostPlaceholder
}
if target.Protocol == "https" {
settings, err := m.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
log.WithContext(ctx).Warnf("failed to get account settings for service %s: %v", target.ServiceID, err)
return unknownHostPlaceholder
}
dnsDomain := m.networkMapController.GetDNSDomain(settings)
return peer.FQDN(dnsDomain)
}
return peer.IP.String()
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err) return nil, fmt.Errorf("failed to get service: %w", err)
@@ -177,6 +223,14 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
} }
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) { func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil { if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err return nil, err
} }
@@ -187,7 +241,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta()) m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err := m.replaceHostByLookup(ctx, accountID, s) err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
@@ -454,6 +508,14 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
} }
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) { func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil { if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err) return nil, fmt.Errorf("hash secrets: %w", err)
} }
@@ -740,8 +802,16 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
} }
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var s *service.Service var s *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
@@ -772,8 +842,16 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
} }
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error { func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var services []*service.Service var services []*service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID) services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
@@ -1058,7 +1136,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
} }
groupIDs := make([]string, 0, len(groupNames)) groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames { for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID) g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err) return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
} }

View File

@@ -23,6 +23,9 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -697,10 +700,12 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID) err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err) require.NoError(t, err)
permsMgr := permissions.NewManager(testStore)
accountMgr := &mock_server.MockAccountManager{ accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) { GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}, },
} }
@@ -713,9 +718,10 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
require.NoError(t, err) require.NoError(t, err)
mgr := &Manager{ mgr := &Manager{
store: testStore, store: testStore,
accountManager: accountMgr, accountManager: accountMgr,
proxyController: proxyController, permissionsManager: permsMgr,
proxyController: proxyController,
clusterDeriver: &testClusterDeriver{ clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"}, domains: []string{"test.netbird.io"},
}, },
@@ -1124,6 +1130,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl) mockAcct := account.NewMockManager(ctrl)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
@@ -1134,9 +1141,10 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
mgr := &Manager{ mgr := &Manager{
store: sqlStore, store: sqlStore,
accountManager: mockAcct, permissionsManager: mockPerms,
proxyController: proxyController, accountManager: mockAcct,
proxyController: proxyController,
} }
service := &rpservice.Service{ service := &rpservice.Service{
@@ -1159,6 +1167,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion") require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
mockAcct.EXPECT(). mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any()) StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT(). mockAcct.EXPECT().

View File

@@ -6,11 +6,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/shared/auth" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -20,19 +17,25 @@ type handler struct {
manager zones.Manager manager zones.Manager
} }
func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) { func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS") router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
} }
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId) allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@@ -47,7 +50,13 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *
util.WriteJSONObject(r.Context(), w, apiZones) util.WriteJSONObject(r.Context(), w, apiZones)
} }
func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiDnsZonesJSONRequestBody var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -57,7 +66,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *a
zone := new(zones.Zone) zone := new(zones.Zone)
zone.FromAPIRequest(&req) zone.FromAPIRequest(&req)
if err := zone.Validate(); err != nil { if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -71,7 +80,13 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
} }
func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -87,7 +102,13 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
} }
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -95,7 +116,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
} }
var req api.PutApiDnsZonesZoneIdJSONRequestBody var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
@@ -104,7 +125,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
zone.FromAPIRequest(&req) zone.FromAPIRequest(&req)
zone.ID = zoneID zone.ID = zoneID
if err := zone.Validate(); err != nil { if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -118,14 +139,20 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *a
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
} }
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return return
} }
if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil { if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -7,34 +7,62 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
dnsDomain string permissionsManager permissions.Manager
dnsDomain string
} }
func NewManager(store store.Store, accountManager account.Manager, dnsDomain string) zones.Manager { func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
dnsDomain: dnsDomain, permissionsManager: permissionsManager,
dnsDomain: dnsDomain,
} }
} }
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) { func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID) return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) { func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID) return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
} }
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) { func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
var err error ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil { if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err return nil, err
} }
@@ -74,6 +102,14 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
} }
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) { func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID) zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err) return nil, fmt.Errorf("failed to get zone: %w", err)
@@ -114,6 +150,14 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
} }
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error { func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -13,6 +13,9 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -26,7 +29,7 @@ const (
testDNSDomain = "netbird.selfhosted" testDNSDomain = "netbird.selfhosted"
) )
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *gomock.Controller, func()) { func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper() t.Helper()
ctx := context.Background() ctx := context.Background()
@@ -46,17 +49,23 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccoun
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{} mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := NewManager(testStore, mockAccountManager, testDNSDomain).(*managerImpl) manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
return manager, testStore, mockAccountManager, ctrl, cleanup return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
} }
func TestManagerImpl_GetAllZones(t *testing.T) { func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t) manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -68,6 +77,10 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
err = testStore.CreateZone(ctx, zone2) err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID) result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, result, 2) assert.Len(t, result, 2)
@@ -75,13 +88,43 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
assert.Equal(t, zone2.ID, result[1].ID) assert.Equal(t, zone2.ID, result[1].ID)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
} }
func TestManagerImpl_GetZone(t *testing.T) { func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t) manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -89,6 +132,10 @@ func TestManagerImpl_GetZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone) err := testStore.CreateZone(ctx, zone)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID) result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID) assert.Equal(t, zone.ID, result.ID)
@@ -96,13 +143,29 @@ func TestManagerImpl_GetZone(t *testing.T) {
assert.Equal(t, zone.Domain, result.Domain) assert.Equal(t, zone.Domain, result.Domain)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
} }
func TestManagerImpl_CreateZone(t *testing.T) { func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, ctrl, cleanup := setupTest(t) manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -114,6 +177,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -132,8 +199,31 @@ func TestManagerImpl_CreateZone(t *testing.T) {
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups) assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) { t.Run("invalid group", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t) manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -143,13 +233,17 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{"invalid-group"}, DistributionGroups: []string{"invalid-group"},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
}) })
t.Run("duplicate domain", func(t *testing.T) { t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t) manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -165,6 +259,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -175,7 +273,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
}) })
t.Run("peer DNS domain conflict", func(t *testing.T) { t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t) manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -193,6 +291,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -203,7 +305,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
}) })
t.Run("default DNS domain conflict", func(t *testing.T) { t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t) manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -215,6 +317,10 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -229,7 +335,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t) manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -246,6 +352,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -265,7 +375,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
}) })
t.Run("domain change not allowed", func(t *testing.T) { t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, ctrl, cleanup := setupTest(t) manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -282,6 +392,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID}, DistributionGroups: []string{testGroupID},
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -291,8 +405,31 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
assert.Equal(t, status.InvalidArgument, s.Type()) assert.Equal(t, status.InvalidArgument, s.Type())
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) { t.Run("zone not found", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t) manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -302,6 +439,10 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
Domain: "example.com", Domain: "example.com",
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -312,7 +453,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success with records", func(t *testing.T) { t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t) manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -328,6 +469,10 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2) err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0 storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++ storeEventCallCount++
@@ -348,7 +493,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
}) })
t.Run("success without records", func(t *testing.T) { t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t) manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -356,6 +501,10 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone) err := testStore.CreateZone(ctx, zone)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -373,11 +522,31 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
t.Run("zone not found", func(t *testing.T) { t.Run("permission denied", func(t *testing.T) {
manager, _, _, ctrl, cleanup := setupTest(t) manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone") err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err) require.Error(t, err)
}) })

View File

@@ -6,11 +6,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/shared/auth" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -20,19 +17,25 @@ type handler struct {
manager records.Manager manager records.Manager
} }
func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) { func RegisterEndpoints(router *mux.Router, manager records.Manager) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS") router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
} }
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -53,7 +56,13 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, apiRecords) util.WriteJSONObject(r.Context(), w, apiRecords)
} }
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -69,7 +78,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth
record := new(records.Record) record := new(records.Record)
record.FromAPIRequest(&req) record.FromAPIRequest(&req)
if err := record.Validate(); err != nil { if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -83,7 +92,13 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
} }
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -105,7 +120,13 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *au
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
} }
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -119,7 +140,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
} }
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
@@ -128,7 +149,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
record.FromAPIRequest(&req) record.FromAPIRequest(&req)
record.ID = recordID record.ID = recordID
if err := record.Validate(); err != nil { if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return return
} }
@@ -142,7 +163,13 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
} }
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"] zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" { if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -155,7 +182,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth
return return
} }
if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil { if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@@ -9,36 +9,64 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
permissionsManager permissions.Manager
} }
func NewManager(store store.Store, accountManager account.Manager) records.Manager { func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
permissionsManager: permissionsManager,
} }
} }
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) { func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID) return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
} }
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) { func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID) return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
} }
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) { func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL) record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)
@@ -73,11 +101,18 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
} }
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) { func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone var zone *zones.Zone
var record *records.Record var record *records.Record
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)
@@ -125,11 +160,18 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
} }
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error { func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record var record *records.Record
var zone *zones.Zone var zone *zones.Zone
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get zone: %w", err) return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -12,8 +12,12 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -23,7 +27,7 @@ const (
testGroupID = "test-group-id" testGroupID = "test-group-id"
) )
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *gomock.Controller, func()) { func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper() t.Helper()
ctx := context.Background() ctx := context.Background()
@@ -47,17 +51,22 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_serv
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{} mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := NewManager(testStore, mockAccountManager).(*managerImpl) manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
return manager, testStore, zone, mockAccountManager, ctrl, cleanup return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
} }
func TestManagerImpl_GetAllRecords(t *testing.T) { func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -69,6 +78,10 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2) err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, result, 2) assert.Len(t, result, 2)
@@ -76,13 +89,43 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
assert.Equal(t, record2.ID, result[1].ID) assert.Equal(t, record2.ID, result[1].ID)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
} }
func TestManagerImpl_GetRecord(t *testing.T) { func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -90,6 +133,10 @@ func TestManagerImpl_GetRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record) err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID) result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, record.ID, result.ID) assert.Equal(t, record.ID, result.ID)
@@ -99,13 +146,29 @@ func TestManagerImpl_GetRecord(t *testing.T) {
assert.Equal(t, record.TTL, result.TTL) assert.Equal(t, record.TTL, result.TTL)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
} }
func TestManagerImpl_CreateRecord(t *testing.T) { func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success - A record", func(t *testing.T) { t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t) manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -116,6 +179,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -135,7 +202,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("success - AAAA record", func(t *testing.T) { t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t) manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -146,6 +213,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 600, TTL: 600,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -160,7 +231,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("success - CNAME record", func(t *testing.T) { t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t) manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -171,6 +242,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID) assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID) assert.Equal(t, testAccountID, accountID)
@@ -184,8 +259,32 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
assert.Equal(t, inputRecord.Content, result.Content) assert.Equal(t, inputRecord.Content, result.Content)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) { t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, ctrl, cleanup := setupTest(t) manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -196,6 +295,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -203,7 +306,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("duplicate record", func(t *testing.T) { t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -218,6 +321,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -225,7 +332,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
}) })
t.Run("CNAME conflict with existing A record", func(t *testing.T) { t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -240,6 +347,10 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -251,7 +362,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -267,6 +378,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Changed TTL TTL: 600, // Changed TTL
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -285,7 +400,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
}) })
t.Run("update only TTL - no validation", func(t *testing.T) { t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -301,6 +416,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Only TTL changed TTL: 600, // Only TTL changed
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored // Event should be stored
} }
@@ -311,8 +430,33 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
assert.Equal(t, 600, result.TTL) assert.Equal(t, 600, result.TTL)
}) })
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) { t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, ctrl, cleanup := setupTest(t) manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -324,13 +468,17 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, TTL: 600,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
}) })
t.Run("update creates duplicate", func(t *testing.T) { t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, ctrl, cleanup := setupTest(t) manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -350,6 +498,10 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 300, TTL: 300,
} }
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err) require.Error(t, err)
assert.Nil(t, result) assert.Nil(t, result)
@@ -361,7 +513,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background() ctx := context.Background()
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t) manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
@@ -369,6 +521,10 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record) err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err) require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true storeEventCalled = true
@@ -386,11 +542,31 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
require.Error(t, err) require.Error(t, err)
}) })
t.Run("record not found", func(t *testing.T) { t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, ctrl, cleanup := setupTest(t) manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup() defer cleanup()
defer ctrl.Finish() defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record") err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err) require.Error(t, err)
}) })

View File

@@ -30,7 +30,6 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -110,7 +109,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -118,15 +117,6 @@ func (s *BaseServer) APIHandler() http.Handler {
}) })
} }
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()
limiter := middleware.NewAPIRateLimiter(cfg)
limiter.SetEnabled(enabled)
return limiter
})
}
func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server { return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers trustedPeers := s.Config.ReverseProxy.TrustedPeers
@@ -233,7 +223,7 @@ func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
func (s *BaseServer) AccessLogsManager() accesslogs.Manager { func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager { return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.GeoLocationManager()) accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
accessLogManager.StartPeriodicCleanup( accessLogManager.StartPeriodicCleanup(
context.Background(), context.Background(),
s.Config.ReverseProxy.AccessLogRetentionDays, s.Config.ReverseProxy.AccessLogRetentionDays,

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
@@ -28,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
) )
@@ -82,13 +82,13 @@ func (s *BaseServer) SettingsManager() settings.Manager {
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
} }
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, idpConfig) return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
}) })
} }
func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager { return Create(s, func() peers.Manager {
manager := peers.NewManager(s.Store()) manager := peers.NewManager(s.Store(), s.PermissionsManager())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
manager.SetNetworkMapController(s.NetworkMapController()) manager.SetNetworkMapController(s.NetworkMapController())
manager.SetIntegratedPeerValidator(s.IntegratedValidator()) manager.SetIntegratedPeerValidator(s.IntegratedValidator())
@@ -161,43 +161,43 @@ func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
func (s *BaseServer) GroupsManager() groups.Manager { func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager { return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.AccountManager()) return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
}) })
} }
func (s *BaseServer) ResourcesManager() resources.Manager { func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager { return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager()) return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
}) })
} }
func (s *BaseServer) RoutesManager() routers.Manager { func (s *BaseServer) RoutesManager() routers.Manager {
return Create(s, func() routers.Manager { return Create(s, func() routers.Manager {
return routers.NewManager(s.Store(), s.AccountManager()) return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
}) })
} }
func (s *BaseServer) NetworksManager() networks.Manager { func (s *BaseServer) NetworksManager() networks.Manager {
return Create(s, func() networks.Manager { return Create(s, func() networks.Manager {
return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager()) return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
}) })
} }
func (s *BaseServer) ZonesManager() zones.Manager { func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager { return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.DNSDomain()) return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
}) })
} }
func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager { return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager()) return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
}) })
} }
func (s *BaseServer) ServiceManager() service.Manager { func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager { return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager(), s.NetworkMapController())
}) })
} }
@@ -213,7 +213,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager { return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager()) m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
return &m return &m
}) })
} }

View File

@@ -15,7 +15,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -40,6 +39,9 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -280,14 +282,22 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// User that performs the update has to belong to the account. // User that performs the update has to belong to the account.
// Returns an updated Settings // Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
var oldSettings *types.Settings var oldSettings *types.Settings
var updateAccountPeers bool var updateAccountPeers bool
var groupChangesAffectPeers bool var groupChangesAffectPeers bool
var reloadReverseProxy bool var reloadReverseProxy bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool var groupsUpdated bool
var err error
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID) oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
@@ -715,6 +725,15 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err return err
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
if err != nil {
return fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users)) userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
@@ -957,10 +976,6 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
return nil, err return nil, err
} }
if user.AccountID != accountID {
return nil, fmt.Errorf("user %s does not belong to account %s", userID, accountID)
}
key := user.IntegrationReference.CacheKey(accountID, userID) key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key) ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil { if err != nil {
@@ -1272,16 +1287,41 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
// GetAccountByID returns an account associated with this account ID. // GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccount(ctx, accountID) return am.Store.GetAccount(ctx, accountID)
} }
// GetAccountMeta returns the account metadata associated with this account ID. // GetAccountMeta returns the account metadata associated with this account ID.
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountOnboarding retrieves the onboarding information for a specific account. // GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
log.Errorf("failed to get account onboarding for account %s: %v", accountID, err) log.Errorf("failed to get account onboarding for account %s: %v", accountID, err)
@@ -1298,6 +1338,15 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
} }
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
return nil, fmt.Errorf("failed to get account onboarding: %w", err) return nil, fmt.Errorf("failed to get account onboarding: %w", err)
@@ -1356,8 +1405,9 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil return accountID, user.Id, nil
} }
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
// User account association is already validated above by GetUserByUserID return "", "", err
}
if !user.IsServiceUser && userAuth.Invited { if !user.IsServiceUser && userAuth.Invited {
err = am.redeemInvite(ctx, accountID, user.Id) err = am.redeemInvite(ctx, accountID, user.Id)
@@ -1799,6 +1849,13 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
} }
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
} }
@@ -2140,6 +2197,14 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
} }
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error { func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP) updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
if err != nil { if err != nil {
return fmt.Errorf("update peer IP transaction: %w", err) return fmt.Errorf("update peer IP transaction: %w", err)

View File

@@ -60,7 +60,7 @@ type Manager interface {
GetUserByID(ctx context.Context, id string) (*types.User, error) GetUserByID(ctx context.Context, id string) (*types.User, error)
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
} }
// GetGroupByName mocks base method. // GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID) ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret0, _ := ret[0].(*types.Group) ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetGroupByName indicates an expected call of GetGroupByName. // GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
} }
// GetIdentityProvider mocks base method. // GetIdentityProvider mocks base method.
@@ -946,18 +946,18 @@ func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomo
} }
// GetPeers mocks base method. // GetPeers mocks base method.
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*peer.Peer, error) { func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter, all) ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter)
ret0, _ := ret[0].([]*peer.Peer) ret0, _ := ret[0].([]*peer.Peer)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetPeers indicates an expected call of GetPeers. // GetPeers indicates an expected call of GetPeers.
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter, all interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter, all) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter)
} }
// GetPolicy mocks base method. // GetPolicy mocks base method.

View File

@@ -22,10 +22,8 @@ import (
"go.opentelemetry.io/otel/metric/noop" "go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/shared/management/status"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
@@ -51,6 +49,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -2312,29 +2311,6 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
} }
} }
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
ctx := context.Background()
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanUp)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
// Verify the already-expired peer is excluded at the store level
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
require.NoError(t, err)
for _, peer := range peers {
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
}
// Only the non-expired peer with expiration enabled should be returned
require.Len(t, peers, 1)
assert.Equal(t, "notexpired01", peers[0].ID)
}
func TestAccount_GetInactivePeers(t *testing.T) { func TestAccount_GetInactivePeers(t *testing.T) {
type test struct { type test struct {
name string name string
@@ -3148,7 +3124,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store) peersManager := peers.NewManager(store, permissionsManager)
proxyManager := proxy.NewMockManager(ctrl) proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT(). proxyManager.EXPECT().
@@ -3165,7 +3141,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
updateManager := update_channel.NewPeersUpdateManager(metrics) updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store) requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{}) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -3176,7 +3152,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil)) manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
return manager, updateManager, nil return manager, updateManager, nil
} }

View File

@@ -8,6 +8,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
@@ -20,6 +22,14 @@ const (
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
} }
@@ -29,11 +39,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool var updateAccountPeers bool
var eventsToStore []func() var eventsToStore []func()
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err return err
} }

View File

@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -28,6 +28,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -78,6 +79,16 @@ func TestGetDNSSettings(t *testing.T) {
if len(dnsSettings.DisabledManagementGroups) != 1 { if len(dnsSettings.DisabledManagementGroups) != 1 {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
} }
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
s, ok := status.FromError(err)
if !ok && s.Type() != status.PermissionDenied {
t.Errorf("returned error should be Permission Denied, got err: %s", err)
}
} }
func TestSaveDNSSettings(t *testing.T) { func TestSaveDNSSettings(t *testing.T) {
@@ -212,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store) peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background() ctx := context.Background()
@@ -223,7 +234,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
updateManager := update_channel.NewPeersUpdateManager(metrics) updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store) requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{}) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
} }

View File

@@ -9,8 +9,11 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
func isEnabled() bool { func isEnabled() bool {
@@ -20,6 +23,14 @@ func isEnabled() bool {
// GetEvents returns a list of activity events of an account // GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true) events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -12,6 +12,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
@@ -30,24 +32,13 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups // CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
// This method is called from authenticated/authorized handlers, so we just validate
// that the user exists and is part of the account
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return err return err
} }
if user == nil { if !allowed {
return status.NewUserNotFoundError(userID) return status.NewPermissionDeniedError()
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsBlocked() {
return status.NewUserBlockedError()
} }
return nil return nil
@@ -70,17 +61,27 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }
// CreateGroup object of the peers // CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
@@ -124,11 +125,19 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
// UpdateGroup object of the peers // UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
@@ -187,24 +196,33 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. // This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
var globalErr error var globalErr error
groupIDs := make([]string, 0, len(groups)) groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups { for _, newGroup := range groups {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
newGroup.AccountID = accountID newGroup.AccountID = accountID
if err := transaction.CreateGroup(ctx, newGroup); err != nil { if err = transaction.CreateGroup(ctx, newGroup); err != nil {
return err return err
} }
if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil { err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return err return err
} }
@@ -225,7 +243,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
} }
} }
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil { if err != nil {
return err return err
@@ -247,14 +264,21 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
// It is the caller's responsibility to ensure proper locking is in place before invoking this method. // It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. // This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func() var eventsToStore []func()
var updateAccountPeers bool var updateAccountPeers bool
var globalErr error var globalErr error
groupIDs := make([]string, 0, len(groups)) groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups { for _, newGroup := range groups {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err return err
} }
@@ -287,7 +311,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
} }
} }
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil { if err != nil {
return err return err
@@ -393,6 +416,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
// If an error occurs while deleting a group, the function skips it and continues deleting other groups. // If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end. // Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var allErrors error var allErrors error
var groupIDsToDelete []string var groupIDsToDelete []string
var deletedGroups []*types.Group var deletedGroups []*types.Group

View File

@@ -26,6 +26,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer" peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -763,10 +764,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// Saving a group linked to network router should update account peers and send peer update // Saving a group linked to network router should update account peers and send peer update
t.Run("saving group linked to network router", func(t *testing.T) { t.Run("saving group linked to network router", func(t *testing.T) {
groupsManager := groups.NewManager(manager.Store, manager) permissionsManager := permissions.NewManager(manager.Store)
resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager) groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
routersManager := routers.NewManager(manager.Store, manager) resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager) routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{ network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
ID: "network_test", ID: "network_test",

Some files were not shown because too many files have changed in this diff Show More