mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
Compare commits
10 Commits
fix/manage
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5da05ecca6 | ||
|
|
801de8c68d | ||
|
|
a822a33240 | ||
|
|
57b23c5b25 | ||
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b | ||
|
|
eb3aa96257 | ||
|
|
064ec1c832 | ||
|
|
75e408f51c |
11
client/firewall/firewalld/firewalld.go
Normal file
11
client/firewall/firewalld/firewalld.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
// versions, which returns EPERM to any other process that tries to insert
// rules into them. The workaround mirrors what Tailscale does: let firewalld
// itself add the accept rules to its own chains by trusting the interface.
package firewalld

// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"
260
client/firewall/firewalld/firewalld_linux.go
Normal file
260
client/firewall/firewalld/firewalld_linux.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/godbus/dbus/v5"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// Well-known DBus identifiers for the firewalld daemon.
	dbusDest      = "org.fedoraproject.FirewallD1"
	dbusPath      = "/org/fedoraproject/FirewallD1"
	dbusRootIface = "org.fedoraproject.FirewallD1"
	dbusZoneIface = "org.fedoraproject.FirewallD1.zone"

	// Error-code substrings firewalld embeds in DBus error bodies and in
	// firewall-cmd output; matched via strings.Contains in dbusErrContains.
	errZoneAlreadySet = "ZONE_ALREADY_SET"
	errAlreadyEnabled = "ALREADY_ENABLED"
	errUnknownIface   = "UNKNOWN_INTERFACE"
	errNotEnabled     = "NOT_ENABLED"

	// callTimeout bounds each individual DBus or firewall-cmd invocation.
	// A fresh context is created for each call so a slow DBus probe can't
	// exhaust the deadline before the firewall-cmd fallback gets to run.
	callTimeout = 3 * time.Second
)
|
|
||||||
|
var (
	// errDBusUnavailable marks failures of the system bus itself (as opposed
	// to firewalld rejecting a call); it triggers the firewall-cmd fallback.
	errDBusUnavailable = errors.New("firewalld dbus unavailable")

	// trustLogOnce ensures the "added to trusted zone" message is logged at
	// Info level only for the first successful add per process; repeat adds
	// from other init paths are quieter.
	trustLogOnce sync.Once

	// parentCtxMu guards parentCtx, the root for per-call timeout contexts
	// used by TrustInterface (see SetParentContext).
	parentCtxMu sync.RWMutex
	parentCtx   context.Context = context.Background()
)
||||||
|
|
||||||
|
// SetParentContext installs a parent context whose cancellation aborts any
|
||||||
|
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||||
|
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||||
|
// during engine shutdown when the engine context is already cancelled.
|
||||||
|
func SetParentContext(ctx context.Context) {
|
||||||
|
parentCtxMu.Lock()
|
||||||
|
parentCtx = ctx
|
||||||
|
parentCtxMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getParentContext() context.Context {
|
||||||
|
parentCtxMu.RLock()
|
||||||
|
defer parentCtxMu.RUnlock()
|
||||||
|
return parentCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||||
|
// running. It is idempotent and best-effort: errors are returned so callers
|
||||||
|
// can log, but a non-running firewalld is not an error. Only the first
|
||||||
|
// successful call per process logs at Info. Respects the parent context set
|
||||||
|
// via SetParentContext so startup-time cancellation unblocks it.
|
||||||
|
func TrustInterface(iface string) error {
|
||||||
|
parent := getParentContext()
|
||||||
|
if !isRunning(parent) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := addTrusted(parent, iface); err != nil {
|
||||||
|
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||||
|
}
|
||||||
|
trustLogOnce.Do(func() {
|
||||||
|
log.Infof("added %s to firewalld trusted zone", iface)
|
||||||
|
})
|
||||||
|
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||||
|
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||||
|
// during shutdown after the engine context has been cancelled.
|
||||||
|
func UntrustInterface(iface string) error {
|
||||||
|
if !isRunning(context.Background()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||||
|
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newCallContext derives a per-call timeout context from parent so each DBus
// or firewall-cmd invocation is individually bounded by callTimeout.
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
	return context.WithTimeout(parent, callTimeout)
}
||||||
|
|
||||||
|
func isRunning(parent context.Context) bool {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
ok, err := isRunningDBus(ctx)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return isRunningCLI(ctx)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func addTrusted(parent context.Context, iface string) error {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
err := addDBus(ctx, iface)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errDBusUnavailable) {
|
||||||
|
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||||
|
}
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return addCLI(ctx, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeTrusted(parent context.Context, iface string) error {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
err := removeDBus(ctx, iface)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errDBusUnavailable) {
|
||||||
|
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||||
|
}
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return removeCLI(ctx, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
var zone string
|
||||||
|
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||||
|
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRunningCLI probes firewalld via `firewall-cmd --state`; a missing binary
// or non-zero exit both mean "not running".
func isRunningCLI(ctx context.Context) bool {
	_, err := exec.LookPath("firewall-cmd")
	if err != nil {
		return false
	}
	cmd := exec.CommandContext(ctx, "firewall-cmd", "--state")
	return cmd.Run() == nil
}
||||||
|
|
||||||
|
func addDBus(ctx context.Context, iface string) error {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||||
|
if call.Err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||||
|
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||||
|
if move.Err != nil {
|
||||||
|
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeDBus(ctx context.Context, iface string) error {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||||
|
if call.Err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addCLI(ctx context.Context, iface string) error {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --change-interface (no --permanent) binds the interface for the
|
||||||
|
// current runtime only; we do not want membership to persist across
|
||||||
|
// reboots because netbird re-asserts it on every startup.
|
||||||
|
out, err := exec.CommandContext(ctx,
|
||||||
|
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||||
|
).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeCLI(ctx context.Context, iface string) error {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.CommandContext(ctx,
|
||||||
|
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||||
|
).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
msg := strings.TrimSpace(string(out))
|
||||||
|
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbusErrContains(err error, code string) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var de dbus.Error
|
||||||
|
if errors.As(err, &de) {
|
||||||
|
for _, b := range de.Body {
|
||||||
|
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Contains(err.Error(), code)
|
||||||
|
}
|
||||||
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/godbus/dbus/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestDBusErrContains verifies that firewalld error codes are matched both
// inside dbus.Error bodies and in plain rendered error strings, including
// the fallback to Error() when the body holds no matching string.
func TestDBusErrContains(t *testing.T) {
	tests := []struct {
		name string
		err  error
		code string
		want bool
	}{
		{"nil error", nil, errZoneAlreadySet, false},
		{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
		{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
		{
			"dbus.Error body match",
			dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
			errZoneAlreadySet,
			true,
		},
		{
			"dbus.Error body miss",
			dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
			errAlreadyEnabled,
			false,
		},
		{
			"dbus.Error non-string body falls back to Error()",
			dbus.Error{Name: "x", Body: []any{123}},
			"x",
			true,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := dbusErrContains(tc.err, tc.code)
			if got != tc.want {
				t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
			}
		})
	}
}
||||||
25
client/firewall/firewalld/firewalld_other.go
Normal file
25
client/firewall/firewalld/firewalld_other.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
	// intentionally empty: firewalld is a Linux-only daemon
}
||||||
|
|
||||||
|
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux. It always reports success so cross-platform callers need
// no build-specific error handling.
func TrustInterface(string) error {
	// intentionally empty: firewalld is a Linux-only daemon
	return nil
}
||||||
|
|
||||||
|
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux. It always reports success so cross-platform callers need
// no build-specific error handling.
func UntrustInterface(string) error {
	// intentionally empty: firewalld is a Linux-only daemon
	return nil
}
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||||
|
// interface in firewalld's trusted zone without a corresponding Close.
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||||
|
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
// attempt to delete state only if all other operations succeeded
|
// attempt to delete state only if all other operations succeeded
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
@@ -40,6 +41,8 @@ const (
|
|||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
chainNameMangleForward = "netbird-mangle-forward"
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
|
firewalldTableName = "firewalld"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
userDataAcceptInputRule = "inputaccept"
|
userDataAcceptInputRule = "inputaccept"
|
||||||
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatPreroutingRules(); err != nil {
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
}
|
}
|
||||||
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
|
|||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to refresh rules: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||||
|
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||||
|
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||||
|
if chain.Table.Name == firewalldTableName {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Skip all iptables-managed tables in the ip family
|
// Skip all iptables-managed tables in the ip family
|
||||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -3,6 +3,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
}
|
}
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AllowNetbird()
|
return m.nativeFirewall.AllowNetbird()
|
||||||
}
|
}
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
|
Name() string
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
|
|||||||
@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
|||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
|
NameFunc func() string
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() wgaddr.Address
|
AddressFunc func() wgaddr.Address
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) Name() string {
|
||||||
|
if i.NameFunc == nil {
|
||||||
|
return "wgtest"
|
||||||
|
}
|
||||||
|
return i.NameFunc()
|
||||||
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
if i.GetWGDeviceFunc == nil {
|
if i.GetWGDeviceFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
|||||||
ipv6Count++
|
ipv6Count++
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
// routing-correctness checks above are the real assertions; the counts
|
||||||
|
// are a sanity bound to catch a totally silent path.
|
||||||
|
minDelivered := packetsPerFamily * 80 / 100
|
||||||
|
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||||
|
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package debug
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Skip("Skipping upload test on docker ci")
|
t.Skip("Skipping upload test on docker ci")
|
||||||
}
|
}
|
||||||
testDir := t.TempDir()
|
testDir := t.TempDir()
|
||||||
testURL := "http://localhost:8080"
|
addr := reserveLoopbackPort(t)
|
||||||
|
testURL := "http://" + addr
|
||||||
t.Setenv("SERVER_URL", testURL)
|
t.Setenv("SERVER_URL", testURL)
|
||||||
|
t.Setenv("SERVER_ADDRESS", addr)
|
||||||
t.Setenv("STORE_DIR", testDir)
|
t.Setenv("STORE_DIR", testDir)
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Errorf("Failed to stop server: %v", err)
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
waitForServer(t, addr)
|
||||||
|
|
||||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, fileContent, createdFileContent)
|
require.Equal(t, fileContent, createdFileContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||||
|
// address, then releases it so the server under test can rebind. The close/
|
||||||
|
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||||
|
// it's essentially never contended in practice.
|
||||||
|
func reserveLoopbackPort(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
addr := l.Addr().String()
|
||||||
|
require.NoError(t, l.Close())
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForServer polls addr until a TCP connect succeeds or a 5s deadline
// expires, failing the test on timeout.
func waitForServer(t *testing.T, addr string) {
	t.Helper()

	deadline := time.Now().Add(5 * time.Second)
	for time.Now().Before(deadline) {
		conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
		if err != nil {
			time.Sleep(20 * time.Millisecond)
			continue
		}
		_ = conn.Close()
		return
	}
	t.Fatalf("server did not start listening on %s in time", addr)
}
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultResolvConfPath = "/etc/resolv.conf"
|
defaultResolvConfPath = "/etc/resolv.conf"
|
||||||
|
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS implements dns.Handler by routing the request through the chain
// with no priority cap (math.MaxInt admits every registered handler).
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	c.dispatch(w, r, math.MaxInt)
}
||||||
|
|
||||||
|
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||||
|
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||||
|
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
|
if entry.Priority > maxPriority {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !c.isHandlerMatch(qname, entry) {
|
if !c.isHandlerMatch(qname, entry) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
|||||||
cw.response.Len(), meta, time.Since(startTime))
|
cw.response.Len(), meta, time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||||
|
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||||
|
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||||
|
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||||
|
// (bounded by the invoked handler's internal timeout).
|
||||||
|
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty question")
|
||||||
|
}
|
||||||
|
|
||||||
|
base := &internalResponseWriter{}
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
c.dispatch(base, r, maxPriority)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||||
|
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||||
|
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||||
|
}
|
||||||
|
return base.response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||||
|
// priority ≤ maxPriority.
|
||||||
|
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, h := range c.handlers {
|
||||||
|
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
switch {
|
switch {
|
||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||||
|
type internalResponseWriter struct {
|
||||||
|
response *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||||
|
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
|
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
|
|
||||||
|
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||||
|
// still surface their answer to ResolveInternal.
|
||||||
|
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(p); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
w.response = msg
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *internalResponseWriter) Close() error { return nil }
|
||||||
|
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||||
|
|
||||||
|
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||||
|
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||||
|
// no-op: in-process queries carry no TSIG state.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack is part of dns.ResponseWriter.
|
||||||
|
func (w *internalResponseWriter) Hijack() {
|
||||||
|
// no-op: in-process queries have no underlying connection to hand off.
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||||
|
// which handler ResolveInternal dispatches to.
|
||||||
|
type answeringHandler struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Answer = []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP(h.ip).To4(),
|
||||||
|
}}
|
||||||
|
_ = w.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *answeringHandler) String() string { return h.name }
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||||
|
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||||
|
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
assert.Equal(t, 1, len(resp.Answer))
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||||
|
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||||
|
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||||
|
type rawWriteHandler struct {
|
||||||
|
ip string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Answer = []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP(h.ip).To4(),
|
||||||
|
}}
|
||||||
|
packed, err := resp.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = w.Write(packed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||||
|
type hangingHandler struct {
|
||||||
|
block chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
<-h.block
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
_ = w.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
h := &hangingHandler{block: make(chan struct{})}
|
||||||
|
defer close(h.block)
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||||
|
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||||
|
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||||
|
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||||
|
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||||
|
|
||||||
|
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||||
|
}
|
||||||
|
|||||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, reason, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("System DNS manager discovered: %s", osManager)
|
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOSDNSManagerType() (osManagerType, error) {
|
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||||
|
resolved := isSystemdResolvedRunning()
|
||||||
|
nss := isLibnssResolveUsed()
|
||||||
|
stub := checkStub()
|
||||||
|
|
||||||
|
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||||
|
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||||
|
// that go through nss-resolve, and in foreign mode they can loop back
|
||||||
|
// through resolved as an upstream.
|
||||||
|
if resolved && (nss || stub) {
|
||||||
|
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", err
|
||||||
|
}
|
||||||
|
if reason != "" {
|
||||||
|
return mgr, reason, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||||
|
if len(rejected) > 0 {
|
||||||
|
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||||
|
}
|
||||||
|
return fileManager, fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||||
|
// matching manager. If reason is empty the caller should pick file mode and
|
||||||
|
// use rejected for diagnostics.
|
||||||
|
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||||
file, err := os.Open(defaultResolvConfPath)
|
file, err := os.Open(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := file.Close(); err != nil {
|
if cerr := file.Close(); cerr != nil {
|
||||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var rejected []string
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if text[0] != '#' {
|
if text[0] != '#' {
|
||||||
return fileManager, nil
|
break
|
||||||
}
|
}
|
||||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||||
return netbirdManager, nil
|
return mgr, reason, nil, nil
|
||||||
}
|
} else if rej != "" {
|
||||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
rejected = append(rejected, rej)
|
||||||
return networkManager, nil
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
|
||||||
if checkStub() {
|
|
||||||
return systemdManager, nil
|
|
||||||
} else {
|
|
||||||
return fileManager, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "resolvconf") {
|
|
||||||
if isSystemdResolveConfMode() {
|
|
||||||
return systemdManager, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return resolvConfManager, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
return 0, fmt.Errorf("scan: %w", err)
|
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||||
|
}
|
||||||
|
return 0, "", rejected, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return fileManager, nil
|
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||||
|
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||||
|
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||||
|
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||||
|
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "NetworkManager") {
|
||||||
|
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
|
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||||
|
}
|
||||||
|
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "resolvconf") {
|
||||||
|
if isSystemdResolveConfMode() {
|
||||||
|
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||||
|
}
|
||||||
|
return resolvConfManager, "resolvconf header detected", ""
|
||||||
|
}
|
||||||
|
return 0, "", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||||
|
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||||
|
// into file mode while resolved is active.
|
||||||
func checkStub() bool {
|
func checkStub() bool {
|
||||||
rConf, err := parseDefaultResolvConf()
|
rConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse resolv conf: %s", err)
|
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,3 +178,36 @@ func checkStub() bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||||
|
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||||
|
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||||
|
func isLibnssResolveUsed() bool {
|
||||||
|
bs, err := os.ReadFile(nsswitchConfPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return parseNsswitchResolveAhead(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseNsswitchResolveAhead(data []byte) bool {
|
||||||
|
for _, line := range strings.Split(string(data), "\n") {
|
||||||
|
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||||
|
line = line[:i]
|
||||||
|
}
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, module := range fields[1:] {
|
||||||
|
switch module {
|
||||||
|
case "dns":
|
||||||
|
return false
|
||||||
|
case "resolve":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
76
client/internal/dns/host_unix_test.go
Normal file
76
client/internal/dns/host_unix_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "resolve before dns with action token",
|
||||||
|
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dns before resolve",
|
||||||
|
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "debian default with only dns",
|
||||||
|
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "neither resolve nor dns",
|
||||||
|
in: "hosts: files myhostname\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no hosts line",
|
||||||
|
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
in: "",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comments and blank lines ignored",
|
||||||
|
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing inline comment",
|
||||||
|
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hosts token must be the first field",
|
||||||
|
in: " hosts: resolve dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other db line mentioning resolve is ignored",
|
||||||
|
in: "networks: resolve\nhosts: dns\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only resolve, no dns",
|
||||||
|
in: "hosts: files resolve\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||||
|
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,40 +2,83 @@ package mgmt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const dnsTimeout = 5 * time.Second
|
const (
|
||||||
|
dnsTimeout = 5 * time.Second
|
||||||
|
defaultTTL = 300 * time.Second
|
||||||
|
refreshBackoff = 30 * time.Second
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains
|
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
||||||
|
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||||
|
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||||
|
// system resolver.
|
||||||
|
type ChainResolver interface {
|
||||||
|
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
||||||
|
HasRootHandlerAtOrBelow(maxPriority int) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
||||||
|
// records and cachedAt are set at construction and treated as immutable;
|
||||||
|
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
||||||
|
// Resolver.mutex.
|
||||||
|
type cachedRecord struct {
|
||||||
|
records []dns.RR
|
||||||
|
cachedAt time.Time
|
||||||
|
lastFailedRefresh time.Time
|
||||||
|
consecFailures int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolver caches critical NetBird infrastructure domains.
|
||||||
|
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
records map[dns.Question][]dns.RR
|
records map[dns.Question]*cachedRecord
|
||||||
mgmtDomain *domain.Domain
|
mgmtDomain *domain.Domain
|
||||||
serverDomains *dnsconfig.ServerDomains
|
serverDomains *dnsconfig.ServerDomains
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
|
||||||
|
|
||||||
type ipsResponse struct {
|
chain ChainResolver
|
||||||
ips []netip.Addr
|
chainMaxPriority int
|
||||||
err error
|
refreshGroup singleflight.Group
|
||||||
|
|
||||||
|
// refreshing tracks questions whose refresh is running via the OS
|
||||||
|
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||||
|
// the OS resolver routed the recursive query back to us (loop). Only
|
||||||
|
// the OS path arms this so chain-path refreshes don't produce false
|
||||||
|
// positives. The atomic bool is CAS-flipped once per refresh to
|
||||||
|
// throttle the warning log.
|
||||||
|
refreshing map[dns.Question]*atomic.Bool
|
||||||
|
|
||||||
|
cacheTTL time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question][]dns.RR),
|
records: make(map[dns.Question]*cachedRecord),
|
||||||
|
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||||
|
cacheTTL: resolveCacheTTL(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
|
|||||||
return "MgmtCacheResolver"
|
return "MgmtCacheResolver"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS implements dns.Handler interface.
|
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
||||||
|
// maxPriority caps which handlers may answer refresh queries (typically
|
||||||
|
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
||||||
|
// mgmt/route/local handlers are skipped).
|
||||||
|
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.chain = chain
|
||||||
|
m.chainMaxPriority = maxPriority
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||||
|
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
records, found := m.records[question]
|
cached, found := m.records[question]
|
||||||
|
inflight := m.refreshing[question]
|
||||||
|
var shouldRefresh bool
|
||||||
|
if found {
|
||||||
|
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||||
|
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
||||||
|
shouldRefresh = stale && !inBackoff
|
||||||
|
}
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||||
|
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
||||||
|
question.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip scheduling a refresh goroutine if one is already inflight for
|
||||||
|
// this question; singleflight would dedup anyway but skipping avoids
|
||||||
|
// a parked goroutine per stale hit under bursty load.
|
||||||
|
if shouldRefresh && inflight == nil {
|
||||||
|
m.scheduleRefresh(question, cached)
|
||||||
|
}
|
||||||
|
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
resp.SetReply(r)
|
resp.SetReply(r)
|
||||||
resp.Authoritative = false
|
resp.Authoritative = false
|
||||||
resp.RecursionAvailable = true
|
resp.RecursionAvailable = true
|
||||||
|
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
||||||
resp.Answer = append(resp.Answer, records...)
|
|
||||||
|
|
||||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||||
|
|
||||||
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDomain manually adds a domain to cache by resolving it.
|
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||||
|
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||||
|
// entry for that qtype.
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||||
if err != nil {
|
|
||||||
return err
|
if errA != nil && errAAAA != nil {
|
||||||
|
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||||
}
|
}
|
||||||
|
|
||||||
var aRecords, aaaaRecords []dns.RR
|
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||||
for _, ip := range ips {
|
if err := errors.Join(errA, errAAAA); err != nil {
|
||||||
if ip.Is4() {
|
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||||
rr := &dns.A{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: dnsName,
|
|
||||||
Rrtype: dns.TypeA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 300,
|
|
||||||
},
|
|
||||||
A: ip.AsSlice(),
|
|
||||||
}
|
|
||||||
aRecords = append(aRecords, rr)
|
|
||||||
} else if ip.Is6() {
|
|
||||||
rr := &dns.AAAA{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: dnsName,
|
|
||||||
Rrtype: dns.TypeAAAA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 300,
|
|
||||||
},
|
|
||||||
AAAA: ip.AsSlice(),
|
|
||||||
}
|
|
||||||
aaaaRecords = append(aaaaRecords, rr)
|
|
||||||
}
|
}
|
||||||
|
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if len(aRecords) > 0 {
|
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||||
aQuestion := dns.Question{
|
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||||
Name: dnsName,
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
m.records[aQuestion] = aRecords
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(aaaaRecords) > 0 {
|
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||||
aaaaQuestion := dns.Question{
|
|
||||||
Name: dnsName,
|
|
||||||
Qtype: dns.TypeAAAA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
m.records[aaaaQuestion] = aaaaRecords
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
// untouched on error. Caller holds m.mutex.
|
||||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||||
resultChan := make(chan *ipsResponse, 1)
|
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||||
|
switch {
|
||||||
go func() {
|
case len(records) > 0:
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||||
resultChan <- &ipsResponse{
|
case err == nil:
|
||||||
err: err,
|
delete(m.records, q)
|
||||||
ips: ips,
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
var resp *ipsResponse
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
|
||||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
|
||||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case resp = <-resultChan:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.err != nil {
|
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||||
|
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||||
|
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||||
|
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||||
|
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||||
|
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||||
|
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||||
|
return nil, m.refreshQuestion(question, expected)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return resp.ips, nil
|
|
||||||
|
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||||
|
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||||
|
// a resolver loop by spotting a query for this same question arriving on us.
|
||||||
|
// expected pins the cache entry observed at schedule time; mutations only apply
|
||||||
|
// if m.records[question] still points at it.
|
||||||
|
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||||
|
if err != nil {
|
||||||
|
m.markRefreshFailed(question, expected)
|
||||||
|
return fmt.Errorf("parse domain: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
records, err := m.lookupRecords(ctx, d, question)
|
||||||
|
if err != nil {
|
||||||
|
fails := m.markRefreshFailed(question, expected)
|
||||||
|
logf := log.Warnf
|
||||||
|
if fails == 0 || fails > 1 {
|
||||||
|
logf = log.Debugf
|
||||||
|
}
|
||||||
|
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||||
|
if len(records) == 0 {
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.records[question] == expected {
|
||||||
|
delete(m.records, question)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.records[question] != expected {
|
||||||
|
m.mutex.Unlock()
|
||||||
|
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) markRefreshing(question dns.Question) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.refreshing[question] = &atomic.Bool{}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
delete(m.refreshing, question)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||||
|
// count so callers can downgrade subsequent failure logs to debug.
|
||||||
|
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
c, ok := m.records[question]
|
||||||
|
if !ok || c != expected {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
c.lastFailedRefresh = time.Now()
|
||||||
|
c.consecFailures++
|
||||||
|
return c.consecFailures
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
||||||
|
// callers tell records, NODATA (nil err, no records), and failure apart.
|
||||||
|
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
chain := m.chain
|
||||||
|
maxPriority := m.chainMaxPriority
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||||
|
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||||
|
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
||||||
|
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
||||||
|
// not the system resolver, so net.DefaultResolver will not loop back.
|
||||||
|
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
||||||
|
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
||||||
|
// arms the loop detector for the duration of its call so that ServeDNS can
|
||||||
|
// spot the OS resolver routing the recursive query back to us.
|
||||||
|
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
chain := m.chain
|
||||||
|
maxPriority := m.chainMaxPriority
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||||
|
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: drop once every supported OS registers a fallback resolver.
|
||||||
|
m.markRefreshing(q)
|
||||||
|
defer m.clearRefreshing(q)
|
||||||
|
|
||||||
|
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||||
|
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
||||||
|
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||||
|
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||||
|
msg := &dns.Msg{}
|
||||||
|
msg.SetQuestion(dnsName, qtype)
|
||||||
|
msg.RecursionDesired = true
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("chain resolve: %w", err)
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
return nil, fmt.Errorf("chain resolve returned nil response")
|
||||||
|
}
|
||||||
|
if resp.Rcode != dns.RcodeSuccess {
|
||||||
|
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := uint32(m.cacheTTL.Seconds())
|
||||||
|
owners := cnameOwners(dnsName, resp.Answer)
|
||||||
|
var filtered []dns.RR
|
||||||
|
for _, rr := range resp.Answer {
|
||||||
|
h := rr.Header()
|
||||||
|
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
||||||
|
filtered = append(filtered, cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||||
|
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||||
|
// returns (nil, nil).
|
||||||
|
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||||
|
network := resutil.NetworkForQtype(qtype)
|
||||||
|
if network == "" {
|
||||||
|
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||||
|
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||||
|
|
||||||
|
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
||||||
|
if result.Rcode == dns.RcodeSuccess {
|
||||||
|
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
||||||
|
// so downstream resolvers don't cache an answer for longer than we will.
|
||||||
|
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
||||||
|
remaining := m.cacheTTL - time.Since(cachedAt)
|
||||||
|
if remaining <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return uint32((remaining + time.Second - 1) / time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
aQuestion := dns.Question{
|
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
Name: dnsName,
|
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
||||||
Qtype: dns.TypeA,
|
delete(m.records, qA)
|
||||||
Qclass: dns.ClassINET,
|
delete(m.records, qAAAA)
|
||||||
}
|
delete(m.refreshing, qA)
|
||||||
delete(m.records, aQuestion)
|
delete(m.refreshing, qAAAA)
|
||||||
|
|
||||||
aaaaQuestion := dns.Question{
|
|
||||||
Name: dnsName,
|
|
||||||
Qtype: dns.TypeAAAA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
delete(m.records, aaaaQuestion)
|
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
||||||
|
// A/AAAA records return nil.
|
||||||
|
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
||||||
|
switch r := rr.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
cp := *r
|
||||||
|
cp.Hdr.Name = owner
|
||||||
|
cp.Hdr.Ttl = ttl
|
||||||
|
cp.A = slices.Clone(r.A)
|
||||||
|
return &cp
|
||||||
|
case *dns.AAAA:
|
||||||
|
cp := *r
|
||||||
|
cp.Hdr.Name = owner
|
||||||
|
cp.Hdr.Ttl = ttl
|
||||||
|
cp.AAAA = slices.Clone(r.AAAA)
|
||||||
|
return &cp
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
||||||
|
// stamping ttl so the response shares no memory with the cached slice.
|
||||||
|
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||||
|
out := make([]dns.RR, 0, len(records))
|
||||||
|
for _, rr := range records {
|
||||||
|
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
||||||
|
out = append(out, cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
||||||
|
// in answer, iterating until fixed point so out-of-order chains resolve.
|
||||||
|
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
||||||
|
owners := map[string]bool{dnsName: true}
|
||||||
|
for {
|
||||||
|
added := false
|
||||||
|
for _, rr := range answer {
|
||||||
|
cname, ok := rr.(*dns.CNAME)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
||||||
|
if !owners[name] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
target := strings.ToLower(dns.Fqdn(cname.Target))
|
||||||
|
if !owners[target] {
|
||||||
|
owners[target] = true
|
||||||
|
added = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !added {
|
||||||
|
return owners
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
||||||
|
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
||||||
|
func resolveCacheTTL() time.Duration {
|
||||||
|
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultTTL
|
||||||
|
}
|
||||||
|
|||||||
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
package mgmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeChain struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls map[string]int
|
||||||
|
answers map[string][]dns.RR
|
||||||
|
err error
|
||||||
|
hasRoot bool
|
||||||
|
onLookup func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeChain() *fakeChain {
|
||||||
|
return &fakeChain{
|
||||||
|
calls: map[string]int{},
|
||||||
|
answers: map[string][]dns.RR{},
|
||||||
|
hasRoot: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return f.hasRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
q := msg.Question[0]
|
||||||
|
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
||||||
|
f.calls[key]++
|
||||||
|
answers := f.answers[key]
|
||||||
|
err := f.err
|
||||||
|
onLookup := f.onLookup
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
if onLookup != nil {
|
||||||
|
onLookup()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(msg)
|
||||||
|
resp.Answer = answers
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
key := name + "|" + dns.TypeToString[qtype]
|
||||||
|
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
||||||
|
switch qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitFor polls the predicate until it returns true or the deadline passes.
|
||||||
|
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(d)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if fn() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatalf("condition not met within %s", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
||||||
|
t.Helper()
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.SetQuestion(name, dns.TypeA)
|
||||||
|
w := &test.MockResponseWriter{}
|
||||||
|
r.ServeDNS(w, msg)
|
||||||
|
return w.GetLastResponse()
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstA(t *testing.T, resp *dns.Msg) string {
|
||||||
|
t.Helper()
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
require.True(t, ok, "expected A record")
|
||||||
|
return a.A.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||||
|
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
||||||
|
// trigger a background refresh, the longer one must not. Proves that the
|
||||||
|
// per-Resolver cacheTTL field actually drives the stale decision.
|
||||||
|
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
||||||
|
|
||||||
|
newRec := func() *cachedRecord {
|
||||||
|
return &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: cachedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
|
||||||
|
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
r.cacheTTL = 10 * time.Millisecond
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
r.records[q] = newRec()
|
||||||
|
|
||||||
|
resp := queryA(t, r, q.Name)
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||||
|
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
return chain.callCount(q.Name, dns.TypeA) >= 1
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
r.cacheTTL = time.Hour
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
r.records[q] = newRec()
|
||||||
|
|
||||||
|
resp := queryA(t, r, q.Name)
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(), // fresh
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
||||||
|
}
|
||||||
|
|
||||||
|
// First query: serves stale immediately.
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||||
|
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
||||||
|
})
|
||||||
|
|
||||||
|
// Next query should now return the refreshed IP.
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
|
||||||
|
var inflight atomic.Int32
|
||||||
|
var maxInflight atomic.Int32
|
||||||
|
chain.onLookup = func() {
|
||||||
|
cur := inflight.Add(1)
|
||||||
|
defer inflight.Add(-1)
|
||||||
|
for {
|
||||||
|
prev := maxInflight.Load()
|
||||||
|
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
||||||
|
}
|
||||||
|
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
waitFor(t, 2*time.Second, func() bool {
|
||||||
|
return inflight.Load() == 0
|
||||||
|
})
|
||||||
|
|
||||||
|
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
||||||
|
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
||||||
|
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.err = errors.New("boom")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||||
|
}
|
||||||
|
|
||||||
|
// First stale hit triggers a refresh attempt that fails.
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
||||||
|
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
||||||
|
})
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
r.mutex.RLock()
|
||||||
|
defer r.mutex.RUnlock()
|
||||||
|
c, ok := r.records[q]
|
||||||
|
return ok && !c.lastFailedRefresh.IsZero()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.hasRoot = false
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
// With hasRoot=false the chain must not be consulted. Use a short
|
||||||
|
// deadline so the OS fallback returns quickly without waiting on a
|
||||||
|
// real network call in CI.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
||||||
|
|
||||||
|
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
||||||
|
"chain must not be used when no root handler is registered at the bound priority")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||||
|
// ServeDNS being invoked for a question while a refresh for that question
|
||||||
|
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||||
|
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||||
|
r := NewResolver()
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate an inflight refresh.
|
||||||
|
r.markRefreshing(q)
|
||||||
|
defer r.clearRefreshing(q)
|
||||||
|
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
inflight := r.refreshing[q]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
require.NotNil(t, inflight)
|
||||||
|
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.markRefreshing(q)
|
||||||
|
defer r.clearRefreshing(q)
|
||||||
|
|
||||||
|
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
||||||
|
// (CompareAndSwap from false -> true returns true only on the first call).
|
||||||
|
for range 5 {
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
inflight := r.refreshing[q]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.True(t, inflight.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, ok := r.refreshing[q]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
||||||
|
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
||||||
|
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
||||||
|
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
|
|||||||
assert.False(t, resolver.MatchSubdomains())
|
assert.False(t, resolver.MatchSubdomains())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResolveCacheTTL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{"unset falls back to default", "", defaultTTL},
|
||||||
|
{"valid duration", "45s", 45 * time.Second},
|
||||||
|
{"valid minutes", "2m", 2 * time.Minute},
|
||||||
|
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
||||||
|
{"zero falls back to default", "0s", defaultTTL},
|
||||||
|
{"negative falls back to default", "-5s", defaultTTL},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Setenv(envMgmtCacheTTL, tc.value)
|
||||||
|
got := resolveCacheTTL()
|
||||||
|
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||||
|
t.Setenv(envMgmtCacheTTL, "7s")
|
||||||
|
r := NewResolver()
|
||||||
|
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ResponseTTL(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cacheTTL time.Duration
|
||||||
|
cachedAt time.Time
|
||||||
|
wantMin uint32
|
||||||
|
wantMax uint32
|
||||||
|
}{
|
||||||
|
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
||||||
|
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
||||||
|
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
||||||
|
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := &Resolver{cacheTTL: tc.cacheTTL}
|
||||||
|
got := r.responseTTL(tc.cachedAt)
|
||||||
|
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
||||||
|
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ func newDefaultServer(
|
|||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
mgmtCacheResolver := mgmt.NewResolver()
|
mgmtCacheResolver := mgmt.NewResolver()
|
||||||
|
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -604,6 +605,8 @@ func (e *Engine) createFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
firewalld.SetParentContext(e.ctx)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
|
|
||||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||||
|
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||||
|
// can outlive the test and call t.Logf after teardown, panicking.
|
||||||
|
receiveDone := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
select {
|
||||||
|
case <-receiveDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("Receive goroutine did not exit after Close")
|
||||||
|
}
|
||||||
|
})
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := client.Close()
|
err := client.Close()
|
||||||
assert.NoError(t, err, "failed to close flow")
|
assert.NoError(t, err, "failed to close flow")
|
||||||
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
receivedAfterReconnect := make(chan struct{})
|
receivedAfterReconnect := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(receiveDone)
|
||||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
|
|||||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||||
echo "Registering CrowdSec bouncer..."
|
echo "Registering CrowdSec bouncer..."
|
||||||
local cs_retries=0
|
local cs_retries=0
|
||||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
|
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
|
||||||
cs_retries=$((cs_retries + 1))
|
cs_retries=$((cs_retries + 1))
|
||||||
if [[ $cs_retries -ge 30 ]]; then
|
if [[ $cs_retries -ge 30 ]]; then
|
||||||
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
@@ -87,17 +88,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
switch authType {
|
switch authType {
|
||||||
case "bearer":
|
case "bearer":
|
||||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
h.ServeHTTP(w, request)
|
|
||||||
case "token":
|
case "token":
|
||||||
request, err := m.checkPATFromRequest(r, authHeader)
|
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||||
// Check if it's a status error, otherwise default to Unauthorized
|
// Check if it's a status error, otherwise default to Unauthorized
|
||||||
if _, ok := status.FromError(err); !ok {
|
if _, ok := status.FromError(err); !ok {
|
||||||
@@ -106,7 +104,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.ServeHTTP(w, request)
|
h.ServeHTTP(w, r)
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||||
return
|
return
|
||||||
@@ -115,19 +113,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckJWTFromRequest checks if the JWT is valid
|
// CheckJWTFromRequest checks if the JWT is valid
|
||||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
// If an error occurs, call the error handler and return an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, fmt.Errorf("error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||||
@@ -143,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if userAuth.AccountId != accountId {
|
if userAuth.AccountId != accountId {
|
||||||
@@ -153,7 +151,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
|
|
||||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||||
@@ -164,17 +162,19 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
// propagates ctx change to upstream middleware
|
||||||
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, fmt.Errorf("error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.patUsageTracker != nil {
|
if m.patUsageTracker != nil {
|
||||||
@@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
|
|
||||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||||
if !m.rateLimiter.Allow(token) {
|
if !m.rateLimiter.Allow(token) {
|
||||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, fmt.Errorf("invalid Token: %w", err)
|
return fmt.Errorf("invalid Token: %w", err)
|
||||||
}
|
}
|
||||||
if time.Now().After(pat.GetExpirationDate()) {
|
if time.Now().After(pat.GetExpirationDate()) {
|
||||||
return r, fmt.Errorf("token expired")
|
return fmt.Errorf("token expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth := auth.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
@@ -216,7 +216,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
// propagates ctx change to upstream middleware
|
||||||
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isTerraformRequest(r *http.Request) bool {
|
func isTerraformRequest(r *http.Request) bool {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
_ "embed"
|
_ "embed"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -46,33 +47,52 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
var isUpdate = policy.ID != ""
|
var isUpdate = policy.ID != ""
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
var action = activity.PolicyAdded
|
var action = activity.PolicyAdded
|
||||||
|
var unchanged bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
saveFunc := transaction.CreatePolicy
|
|
||||||
if isUpdate {
|
if isUpdate {
|
||||||
action = activity.PolicyUpdated
|
if policy.Equal(existingPolicy) {
|
||||||
saveFunc = transaction.SavePolicy
|
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||||
|
unchanged = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = saveFunc(ctx, policy); err != nil {
|
action = activity.PolicyUpdated
|
||||||
|
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if unchanged {
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -138,14 +158,18 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
|||||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||||
if isUpdate {
|
for _, rule := range policy.Rules {
|
||||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
if err != nil {
|
return true, nil
|
||||||
return false, err
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||||
|
}
|
||||||
|
|
||||||
|
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||||
if !policy.Enabled && !existingPolicy.Enabled {
|
if !policy.Enabled && !existingPolicy.Enabled {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@@ -164,7 +188,6 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
|||||||
if hasPeers {
|
if hasPeers {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
|||||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||||
}
|
}
|
||||||
|
|
||||||
// validatePolicy validates the policy and its rules.
|
// validatePolicy validates the policy and its rules. For updates it returns
|
||||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
// the existing policy loaded from the store so callers can avoid a second read.
|
||||||
|
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||||
|
var existingPolicy *types.Policy
|
||||||
if policy.ID != "" {
|
if policy.ID != "" {
|
||||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
var err error
|
||||||
|
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Refactor to support multiple rules per policy
|
// TODO: Refactor to support multiple rules per policy
|
||||||
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
|
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
|
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, rule := range policy.Rules {
|
for i, rule := range policy.Rules {
|
||||||
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return existingPolicy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||||
|
|||||||
@@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
h.ServeHTTP(w, r.WithContext(ctx))
|
// Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
|
||||||
|
req := r.WithContext(ctx)
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
close(handlerDone)
|
close(handlerDone)
|
||||||
|
|
||||||
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
ctx = req.Context()
|
||||||
if err == nil {
|
|
||||||
if userAuth.AccountId != "" {
|
|
||||||
//nolint
|
|
||||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
|
|
||||||
}
|
|
||||||
if userAuth.UserId != "" {
|
|
||||||
//nolint
|
|
||||||
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if w.Status() > 399 {
|
if w.Status() > 399 {
|
||||||
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
||||||
|
|||||||
@@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Policy) Equal(other *Policy) bool {
|
||||||
|
if p == nil || other == nil {
|
||||||
|
return p == other
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.ID != other.ID ||
|
||||||
|
p.AccountID != other.AccountID ||
|
||||||
|
p.Name != other.Name ||
|
||||||
|
p.Description != other.Description ||
|
||||||
|
p.Enabled != other.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.Rules) != len(other.Rules) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
otherRules := make(map[string]*PolicyRule, len(other.Rules))
|
||||||
|
for _, r := range other.Rules {
|
||||||
|
otherRules[r.ID] = r
|
||||||
|
}
|
||||||
|
for _, r := range p.Rules {
|
||||||
|
otherRule, ok := otherRules[r.ID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !r.Equal(otherRule) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// EventMeta returns activity event meta related to this policy
|
// EventMeta returns activity event meta related to this policy
|
||||||
func (p *Policy) EventMeta() map[string]any {
|
func (p *Policy) EventMeta() map[string]any {
|
||||||
return map[string]any{"name": p.Name}
|
return map[string]any{"name": p.Name}
|
||||||
|
|||||||
193
management/server/types/policy_test.go
Normal file
193
management/server/types/policy_test.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentRules(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
{ID: "r2", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc3"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
|
||||||
|
base := Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Description: "desc",
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
other := base
|
||||||
|
other.Name = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Enabled = false
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Description = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_NilCases(t *testing.T) {
|
||||||
|
var a *Policy
|
||||||
|
var b *Policy
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
|
||||||
|
a = &Policy{ID: "pol1"}
|
||||||
|
assert.False(t, a.Equal(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r2", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_FullScenario(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "Web Access",
|
||||||
|
Description: "Allow web access",
|
||||||
|
Enabled: true,
|
||||||
|
SourcePostureChecks: []string{"pc2", "pc1"},
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "r1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "HTTP",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"g2", "g1"},
|
||||||
|
Destinations: []string{"g4", "g3"},
|
||||||
|
Ports: []string{"443", "80", "8080"},
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "Web Access",
|
||||||
|
Description: "Allow web access",
|
||||||
|
Enabled: true,
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "r1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "HTTP",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"g1", "g2"},
|
||||||
|
Destinations: []string{"g3", "g4"},
|
||||||
|
Ports: []string{"80", "8080", "443"},
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
|||||||
}
|
}
|
||||||
return rule
|
return rule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
|
||||||
|
if pm == nil || other == nil {
|
||||||
|
return pm == other
|
||||||
|
}
|
||||||
|
|
||||||
|
if pm.ID != other.ID ||
|
||||||
|
pm.PolicyID != other.PolicyID ||
|
||||||
|
pm.Name != other.Name ||
|
||||||
|
pm.Description != other.Description ||
|
||||||
|
pm.Enabled != other.Enabled ||
|
||||||
|
pm.Action != other.Action ||
|
||||||
|
pm.Bidirectional != other.Bidirectional ||
|
||||||
|
pm.Protocol != other.Protocol ||
|
||||||
|
pm.SourceResource != other.SourceResource ||
|
||||||
|
pm.DestinationResource != other.DestinationResource ||
|
||||||
|
pm.AuthorizedUser != other.AuthorizedUser {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// stringSlicesEqualUnordered reports whether a and b contain the same
// elements with the same multiplicities, ignoring order. nil and empty
// slices compare equal. Neither input is modified: both are copied
// before sorting.
func stringSlicesEqualUnordered(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	if len(a) == 0 {
		return true
	}
	as := append([]string(nil), a...)
	bs := append([]string(nil), b...)
	slices.Sort(as)
	slices.Sort(bs)
	return slices.Equal(as, bs)
}
|
||||||
|
|
||||||
|
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
cmp := func(x, y RulePortRange) int {
|
||||||
|
if x.Start != y.Start {
|
||||||
|
if x.Start < y.Start {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if x.End != y.End {
|
||||||
|
if x.End < y.End {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
sorted1 := make([]RulePortRange, len(a))
|
||||||
|
sorted2 := make([]RulePortRange, len(b))
|
||||||
|
copy(sorted1, a)
|
||||||
|
copy(sorted2, b)
|
||||||
|
slices.SortFunc(sorted1, cmp)
|
||||||
|
slices.SortFunc(sorted2, cmp)
|
||||||
|
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
|
||||||
|
return x.Start == y.Start && x.End == y.End
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func authorizedGroupsEqual(a, b map[string][]string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, va := range a {
|
||||||
|
vb, ok := b[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(va, vb) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
194
management/server/types/policyrule_test.go
Normal file
194
management/server/types/policyrule_test.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "80", "22"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"22", "443", "80"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "80"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "22"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g2", "g3"},
|
||||||
|
Destinations: []string{"g4", "g5"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g3", "g1", "g2"},
|
||||||
|
Destinations: []string{"g5", "g4"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g2"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g3"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 443},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u1", "u2", "u3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u3", "u1", "u2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g2": {"u1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
|
||||||
|
base := PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "test",
|
||||||
|
Description: "desc",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
}
|
||||||
|
|
||||||
|
other := base
|
||||||
|
other.Name = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Enabled = false
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Action = PolicyTrafficActionDrop
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Protocol = PolicyRuleProtocolUDP
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_NilCases(t *testing.T) {
|
||||||
|
var a *PolicyRule
|
||||||
|
var b *PolicyRule
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
|
||||||
|
a = &PolicyRule{ID: "rule1"}
|
||||||
|
assert.False(t, a.Equal(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{},
|
||||||
|
Sources: nil,
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: nil,
|
||||||
|
Sources: []string{},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
@@ -30,6 +30,8 @@ import (
|
|||||||
|
|
||||||
const ConnectTimeout = 10 * time.Second
|
const ConnectTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
const healthCheckTimeout = 5 * time.Second
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
||||||
// for the management client connection. Value is in bytes.
|
// for the management client connection. Value is in bytes.
|
||||||
@@ -532,7 +534,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
|||||||
case connectivity.Ready:
|
case connectivity.Ready:
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/wsproxy"
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const healthCheckTimeout = 5 * time.Second
|
||||||
|
|
||||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||||
type ConnStateNotifier interface {
|
type ConnStateNotifier interface {
|
||||||
MarkSignalDisconnected(error)
|
MarkSignalDisconnected(error)
|
||||||
@@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
|||||||
case connectivity.Ready:
|
case connectivity.Ready:
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
|
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
|
||||||
Key: c.key.PublicKey().String(),
|
Key: c.key.PublicKey().String(),
|
||||||
|
|||||||
Reference in New Issue
Block a user