mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 10:16:38 +00:00
Compare commits
8 Commits
fix/manage
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a822a33240 | ||
|
|
57b23c5b25 | ||
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b | ||
|
|
eb3aa96257 | ||
|
|
064ec1c832 | ||
|
|
75e408f51c |
11
client/firewall/firewalld/firewalld.go
Normal file
11
client/firewall/firewalld/firewalld.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||||
|
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||||
|
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||||
|
// versions, which returns EPERM to any other process that tries to insert
|
||||||
|
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||||
|
// itself add the accept rules to its own chains by trusting the interface.
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||||
|
// should bypass firewalld filtering.
|
||||||
|
const TrustedZone = "trusted"
|
||||||
260
client/firewall/firewalld/firewalld_linux.go
Normal file
260
client/firewall/firewalld/firewalld_linux.go
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/godbus/dbus/v5"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dbusDest = "org.fedoraproject.FirewallD1"
|
||||||
|
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||||
|
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||||
|
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||||
|
|
||||||
|
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||||
|
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||||
|
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||||
|
errNotEnabled = "NOT_ENABLED"
|
||||||
|
|
||||||
|
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||||
|
// A fresh context is created for each call so a slow DBus probe can't
|
||||||
|
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||||
|
callTimeout = 3 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||||
|
|
||||||
|
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||||
|
// Info level only for the first successful add per process; repeat adds
|
||||||
|
// from other init paths are quieter.
|
||||||
|
trustLogOnce sync.Once
|
||||||
|
|
||||||
|
parentCtxMu sync.RWMutex
|
||||||
|
parentCtx context.Context = context.Background()
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetParentContext installs a parent context whose cancellation aborts any
|
||||||
|
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||||
|
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||||
|
// during engine shutdown when the engine context is already cancelled.
|
||||||
|
func SetParentContext(ctx context.Context) {
|
||||||
|
parentCtxMu.Lock()
|
||||||
|
parentCtx = ctx
|
||||||
|
parentCtxMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getParentContext() context.Context {
|
||||||
|
parentCtxMu.RLock()
|
||||||
|
defer parentCtxMu.RUnlock()
|
||||||
|
return parentCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||||
|
// running. It is idempotent and best-effort: errors are returned so callers
|
||||||
|
// can log, but a non-running firewalld is not an error. Only the first
|
||||||
|
// successful call per process logs at Info. Respects the parent context set
|
||||||
|
// via SetParentContext so startup-time cancellation unblocks it.
|
||||||
|
func TrustInterface(iface string) error {
|
||||||
|
parent := getParentContext()
|
||||||
|
if !isRunning(parent) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := addTrusted(parent, iface); err != nil {
|
||||||
|
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||||
|
}
|
||||||
|
trustLogOnce.Do(func() {
|
||||||
|
log.Infof("added %s to firewalld trusted zone", iface)
|
||||||
|
})
|
||||||
|
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||||
|
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||||
|
// during shutdown after the engine context has been cancelled.
|
||||||
|
func UntrustInterface(iface string) error {
|
||||||
|
if !isRunning(context.Background()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||||
|
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(parent, callTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRunning(parent context.Context) bool {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
ok, err := isRunningDBus(ctx)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return isRunningCLI(ctx)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func addTrusted(parent context.Context, iface string) error {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
err := addDBus(ctx, iface)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errDBusUnavailable) {
|
||||||
|
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||||
|
}
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return addCLI(ctx, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeTrusted(parent context.Context, iface string) error {
|
||||||
|
ctx, cancel := newCallContext(parent)
|
||||||
|
err := removeDBus(ctx, iface)
|
||||||
|
cancel()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, errDBusUnavailable) {
|
||||||
|
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||||
|
}
|
||||||
|
ctx, cancel = newCallContext(parent)
|
||||||
|
defer cancel()
|
||||||
|
return removeCLI(ctx, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
var zone string
|
||||||
|
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||||
|
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRunningCLI(ctx context.Context) bool {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDBus(ctx context.Context, iface string) error {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||||
|
if call.Err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||||
|
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||||
|
if move.Err != nil {
|
||||||
|
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeDBus(ctx context.Context, iface string) error {
|
||||||
|
conn, err := dbus.SystemBus()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||||
|
}
|
||||||
|
obj := conn.Object(dbusDest, dbusPath)
|
||||||
|
|
||||||
|
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||||
|
if call.Err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addCLI(ctx context.Context, iface string) error {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --change-interface (no --permanent) binds the interface for the
|
||||||
|
// current runtime only; we do not want membership to persist across
|
||||||
|
// reboots because netbird re-asserts it on every startup.
|
||||||
|
out, err := exec.CommandContext(ctx,
|
||||||
|
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||||
|
).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeCLI(ctx context.Context, iface string) error {
|
||||||
|
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||||
|
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.CommandContext(ctx,
|
||||||
|
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||||
|
).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
msg := strings.TrimSpace(string(out))
|
||||||
|
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dbusErrContains(err error, code string) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var de dbus.Error
|
||||||
|
if errors.As(err, &de) {
|
||||||
|
for _, b := range de.Body {
|
||||||
|
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Contains(err.Error(), code)
|
||||||
|
}
|
||||||
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/godbus/dbus/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDBusErrContains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
code string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil error", nil, errZoneAlreadySet, false},
|
||||||
|
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||||
|
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||||
|
{
|
||||||
|
"dbus.Error body match",
|
||||||
|
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||||
|
errZoneAlreadySet,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dbus.Error body miss",
|
||||||
|
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||||
|
errAlreadyEnabled,
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dbus.Error non-string body falls back to Error()",
|
||||||
|
dbus.Error{Name: "x", Body: []any{123}},
|
||||||
|
"x",
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := dbusErrContains(tc.err, tc.code)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
25
client/firewall/firewalld/firewalld_other.go
Normal file
25
client/firewall/firewalld/firewalld_other.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package firewalld
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||||
|
// runs on Linux.
|
||||||
|
func SetParentContext(context.Context) {
|
||||||
|
// intentionally empty: firewalld is a Linux-only daemon
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||||
|
// runs on Linux.
|
||||||
|
func TrustInterface(string) error {
|
||||||
|
// intentionally empty: firewalld is a Linux-only daemon
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||||
|
// runs on Linux.
|
||||||
|
func UntrustInterface(string) error {
|
||||||
|
// intentionally empty: firewalld is a Linux-only daemon
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||||
|
// interface in firewalld's trusted zone without a corresponding Close.
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
go func() {
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||||
|
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
// attempt to delete state only if all other operations succeeded
|
// attempt to delete state only if all other operations succeeded
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||||
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
@@ -40,6 +41,8 @@ const (
|
|||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
chainNameMangleForward = "netbird-mangle-forward"
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
|
firewalldTableName = "firewalld"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
userDataAcceptInputRule = "inputaccept"
|
userDataAcceptInputRule = "inputaccept"
|
||||||
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatPreroutingRules(); err != nil {
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
}
|
}
|
||||||
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
|
|||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
log.Errorf("failed to refresh rules: %s", err)
|
log.Errorf("failed to refresh rules: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||||
|
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||||
|
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||||
|
if chain.Table.Name == firewalldTableName {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Skip all iptables-managed tables in the ip family
|
// Skip all iptables-managed tables in the ip family
|
||||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -3,6 +3,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
}
|
}
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AllowNetbird()
|
return m.nativeFirewall.AllowNetbird()
|
||||||
}
|
}
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
|
Name() string
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
|
|||||||
@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
|||||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
|
NameFunc func() string
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() wgaddr.Address
|
AddressFunc func() wgaddr.Address
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) Name() string {
|
||||||
|
if i.NameFunc == nil {
|
||||||
|
return "wgtest"
|
||||||
|
}
|
||||||
|
return i.NameFunc()
|
||||||
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
if i.GetWGDeviceFunc == nil {
|
if i.GetWGDeviceFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
|||||||
ipv6Count++
|
ipv6Count++
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
// routing-correctness checks above are the real assertions; the counts
|
||||||
|
// are a sanity bound to catch a totally silent path.
|
||||||
|
minDelivered := packetsPerFamily * 80 / 100
|
||||||
|
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||||
|
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ package debug
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Skip("Skipping upload test on docker ci")
|
t.Skip("Skipping upload test on docker ci")
|
||||||
}
|
}
|
||||||
testDir := t.TempDir()
|
testDir := t.TempDir()
|
||||||
testURL := "http://localhost:8080"
|
addr := reserveLoopbackPort(t)
|
||||||
|
testURL := "http://" + addr
|
||||||
t.Setenv("SERVER_URL", testURL)
|
t.Setenv("SERVER_URL", testURL)
|
||||||
|
t.Setenv("SERVER_ADDRESS", addr)
|
||||||
t.Setenv("STORE_DIR", testDir)
|
t.Setenv("STORE_DIR", testDir)
|
||||||
srv := server.NewServer()
|
srv := server.NewServer()
|
||||||
go func() {
|
go func() {
|
||||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
|||||||
t.Errorf("Failed to stop server: %v", err)
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
waitForServer(t, addr)
|
||||||
|
|
||||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||||
fileContent := []byte("test file content")
|
fileContent := []byte("test file content")
|
||||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, fileContent, createdFileContent)
|
require.Equal(t, fileContent, createdFileContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||||
|
// address, then releases it so the server under test can rebind. The close/
|
||||||
|
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||||
|
// it's essentially never contended in practice.
|
||||||
|
func reserveLoopbackPort(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
addr := l.Addr().String()
|
||||||
|
require.NoError(t, l.Close())
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForServer(t *testing.T, addr string) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(5 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatalf("server did not start listening on %s in time", addr)
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultResolvConfPath = "/etc/resolv.conf"
|
defaultResolvConfPath = "/etc/resolv.conf"
|
||||||
|
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
|
|||||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, err := getOSDNSManagerType()
|
osManager, reason, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("System DNS manager discovered: %s", osManager)
|
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOSDNSManagerType() (osManagerType, error) {
|
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||||
|
resolved := isSystemdResolvedRunning()
|
||||||
|
nss := isLibnssResolveUsed()
|
||||||
|
stub := checkStub()
|
||||||
|
|
||||||
|
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||||
|
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||||
|
// that go through nss-resolve, and in foreign mode they can loop back
|
||||||
|
// through resolved as an upstream.
|
||||||
|
if resolved && (nss || stub) {
|
||||||
|
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||||
|
if err != nil {
|
||||||
|
return 0, "", err
|
||||||
|
}
|
||||||
|
if reason != "" {
|
||||||
|
return mgr, reason, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||||
|
if len(rejected) > 0 {
|
||||||
|
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||||
|
}
|
||||||
|
return fileManager, fallback, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||||
|
// matching manager. If reason is empty the caller should pick file mode and
|
||||||
|
// use rejected for diagnostics.
|
||||||
|
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||||
file, err := os.Open(defaultResolvConfPath)
|
file, err := os.Open(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := file.Close(); err != nil {
|
if cerr := file.Close(); cerr != nil {
|
||||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var rejected []string
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if text[0] != '#' {
|
if text[0] != '#' {
|
||||||
return fileManager, nil
|
break
|
||||||
}
|
}
|
||||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||||
return netbirdManager, nil
|
return mgr, reason, nil, nil
|
||||||
}
|
} else if rej != "" {
|
||||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
rejected = append(rejected, rej)
|
||||||
return networkManager, nil
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
|
||||||
if checkStub() {
|
|
||||||
return systemdManager, nil
|
|
||||||
} else {
|
|
||||||
return fileManager, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "resolvconf") {
|
|
||||||
if isSystemdResolveConfMode() {
|
|
||||||
return systemdManager, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return resolvConfManager, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
return 0, fmt.Errorf("scan: %w", err)
|
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||||
}
|
}
|
||||||
|
return 0, "", rejected, nil
|
||||||
return fileManager, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||||
|
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||||
|
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||||
|
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||||
|
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "NetworkManager") {
|
||||||
|
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
|
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||||
|
}
|
||||||
|
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "resolvconf") {
|
||||||
|
if isSystemdResolveConfMode() {
|
||||||
|
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||||
|
}
|
||||||
|
return resolvConfManager, "resolvconf header detected", ""
|
||||||
|
}
|
||||||
|
return 0, "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||||
|
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||||
|
// into file mode while resolved is active.
|
||||||
func checkStub() bool {
|
func checkStub() bool {
|
||||||
rConf, err := parseDefaultResolvConf()
|
rConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse resolv conf: %s", err)
|
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,3 +178,36 @@ func checkStub() bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||||
|
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||||
|
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||||
|
func isLibnssResolveUsed() bool {
|
||||||
|
bs, err := os.ReadFile(nsswitchConfPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return parseNsswitchResolveAhead(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseNsswitchResolveAhead(data []byte) bool {
|
||||||
|
for _, line := range strings.Split(string(data), "\n") {
|
||||||
|
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||||
|
line = line[:i]
|
||||||
|
}
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, module := range fields[1:] {
|
||||||
|
switch module {
|
||||||
|
case "dns":
|
||||||
|
return false
|
||||||
|
case "resolve":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
76
client/internal/dns/host_unix_test.go
Normal file
76
client/internal/dns/host_unix_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "resolve before dns with action token",
|
||||||
|
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dns before resolve",
|
||||||
|
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "debian default with only dns",
|
||||||
|
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "neither resolve nor dns",
|
||||||
|
in: "hosts: files myhostname\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no hosts line",
|
||||||
|
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
in: "",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "comments and blank lines ignored",
|
||||||
|
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing inline comment",
|
||||||
|
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hosts token must be the first field",
|
||||||
|
in: " hosts: resolve dns\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other db line mentioning resolve is ignored",
|
||||||
|
in: "networks: resolve\nhosts: dns\n",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only resolve, no dns",
|
||||||
|
in: "hosts: files resolve\n",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||||
|
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -604,6 +605,8 @@ func (e *Engine) createFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
firewalld.SetParentContext(e.ctx)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
|
|
||||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||||
|
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||||
|
// can outlive the test and call t.Logf after teardown, panicking.
|
||||||
|
receiveDone := make(chan struct{})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
select {
|
||||||
|
case <-receiveDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("Receive goroutine did not exit after Close")
|
||||||
|
}
|
||||||
|
})
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := client.Close()
|
err := client.Close()
|
||||||
assert.NoError(t, err, "failed to close flow")
|
assert.NoError(t, err, "failed to close flow")
|
||||||
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
|||||||
receivedAfterReconnect := make(chan struct{})
|
receivedAfterReconnect := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(receiveDone)
|
||||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
|
|||||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||||
echo "Registering CrowdSec bouncer..."
|
echo "Registering CrowdSec bouncer..."
|
||||||
local cs_retries=0
|
local cs_retries=0
|
||||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
|
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
|
||||||
cs_retries=$((cs_retries + 1))
|
cs_retries=$((cs_retries + 1))
|
||||||
if [[ $cs_retries -ge 30 ]]; then
|
if [[ $cs_retries -ge 30 ]]; then
|
||||||
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
@@ -87,17 +88,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
switch authType {
|
switch authType {
|
||||||
case "bearer":
|
case "bearer":
|
||||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
h.ServeHTTP(w, request)
|
|
||||||
case "token":
|
case "token":
|
||||||
request, err := m.checkPATFromRequest(r, authHeader)
|
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||||
// Check if it's a status error, otherwise default to Unauthorized
|
// Check if it's a status error, otherwise default to Unauthorized
|
||||||
if _, ok := status.FromError(err); !ok {
|
if _, ok := status.FromError(err); !ok {
|
||||||
@@ -106,7 +104,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.ServeHTTP(w, request)
|
h.ServeHTTP(w, r)
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||||
return
|
return
|
||||||
@@ -115,19 +113,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CheckJWTFromRequest checks if the JWT is valid
|
// CheckJWTFromRequest checks if the JWT is valid
|
||||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
// If an error occurs, call the error handler and return an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, fmt.Errorf("error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||||
@@ -143,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if userAuth.AccountId != accountId {
|
if userAuth.AccountId != accountId {
|
||||||
@@ -153,7 +151,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
|
|
||||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||||
@@ -164,17 +162,19 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
|||||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
// propagates ctx change to upstream middleware
|
||||||
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, fmt.Errorf("error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.patUsageTracker != nil {
|
if m.patUsageTracker != nil {
|
||||||
@@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
|
|
||||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||||
if !m.rateLimiter.Allow(token) {
|
if !m.rateLimiter.Allow(token) {
|
||||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, fmt.Errorf("invalid Token: %w", err)
|
return fmt.Errorf("invalid Token: %w", err)
|
||||||
}
|
}
|
||||||
if time.Now().After(pat.GetExpirationDate()) {
|
if time.Now().After(pat.GetExpirationDate()) {
|
||||||
return r, fmt.Errorf("token expired")
|
return fmt.Errorf("token expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
userAuth := auth.UserAuth{
|
userAuth := auth.UserAuth{
|
||||||
@@ -216,7 +216,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
// propagates ctx change to upstream middleware
|
||||||
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isTerraformRequest(r *http.Request) bool {
|
func isTerraformRequest(r *http.Request) bool {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
_ "embed"
|
_ "embed"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
@@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
var isUpdate = policy.ID != ""
|
var isUpdate = policy.ID != ""
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
var action = activity.PolicyAdded
|
var action = activity.PolicyAdded
|
||||||
|
var unchanged bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
saveFunc := transaction.CreatePolicy
|
|
||||||
if isUpdate {
|
if isUpdate {
|
||||||
action = activity.PolicyUpdated
|
if policy.Equal(existingPolicy) {
|
||||||
saveFunc = transaction.SavePolicy
|
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||||
}
|
unchanged = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if err = saveFunc(ctx, policy); err != nil {
|
action = activity.PolicyUpdated
|
||||||
return err
|
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||||
@@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if unchanged {
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
|||||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||||
if isUpdate {
|
for _, rule := range policy.Rules {
|
||||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !policy.Enabled && !existingPolicy.Enabled {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range existingPolicy.Rules {
|
|
||||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasPeers {
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||||
|
}
|
||||||
|
|
||||||
|
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||||
|
if !policy.Enabled && !existingPolicy.Enabled {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range existingPolicy.Rules {
|
||||||
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
|||||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||||
}
|
}
|
||||||
|
|
||||||
// validatePolicy validates the policy and its rules.
|
// validatePolicy validates the policy and its rules. For updates it returns
|
||||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
// the existing policy loaded from the store so callers can avoid a second read.
|
||||||
|
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||||
|
var existingPolicy *types.Policy
|
||||||
if policy.ID != "" {
|
if policy.ID != "" {
|
||||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
var err error
|
||||||
|
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Refactor to support multiple rules per policy
|
// TODO: Refactor to support multiple rules per policy
|
||||||
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
|
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
|
|
||||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, rule := range policy.Rules {
|
for i, rule := range policy.Rules {
|
||||||
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
|||||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return existingPolicy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||||
|
|||||||
@@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
h.ServeHTTP(w, r.WithContext(ctx))
|
// Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
|
||||||
|
req := r.WithContext(ctx)
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
close(handlerDone)
|
close(handlerDone)
|
||||||
|
|
||||||
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
ctx = req.Context()
|
||||||
if err == nil {
|
|
||||||
if userAuth.AccountId != "" {
|
|
||||||
//nolint
|
|
||||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
|
|
||||||
}
|
|
||||||
if userAuth.UserId != "" {
|
|
||||||
//nolint
|
|
||||||
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if w.Status() > 399 {
|
if w.Status() > 399 {
|
||||||
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
||||||
|
|||||||
@@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Policy) Equal(other *Policy) bool {
|
||||||
|
if p == nil || other == nil {
|
||||||
|
return p == other
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.ID != other.ID ||
|
||||||
|
p.AccountID != other.AccountID ||
|
||||||
|
p.Name != other.Name ||
|
||||||
|
p.Description != other.Description ||
|
||||||
|
p.Enabled != other.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.Rules) != len(other.Rules) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
otherRules := make(map[string]*PolicyRule, len(other.Rules))
|
||||||
|
for _, r := range other.Rules {
|
||||||
|
otherRules[r.ID] = r
|
||||||
|
}
|
||||||
|
for _, r := range p.Rules {
|
||||||
|
otherRule, ok := otherRules[r.ID]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !r.Equal(otherRule) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// EventMeta returns activity event meta related to this policy
|
// EventMeta returns activity event meta related to this policy
|
||||||
func (p *Policy) EventMeta() map[string]any {
|
func (p *Policy) EventMeta() map[string]any {
|
||||||
return map[string]any{"name": p.Name}
|
return map[string]any{"name": p.Name}
|
||||||
|
|||||||
193
management/server/types/policy_test.go
Normal file
193
management/server/types/policy_test.go
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentRules(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
{ID: "r2", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc3"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
|
||||||
|
base := Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "test",
|
||||||
|
Description: "desc",
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
other := base
|
||||||
|
other.Name = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Enabled = false
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Description = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_NilCases(t *testing.T) {
|
||||||
|
var a *Policy
|
||||||
|
var b *Policy
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
|
||||||
|
a = &Policy{ID: "pol1"}
|
||||||
|
assert.False(t, a.Equal(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r1", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{ID: "r2", PolicyID: "pol1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyEqual_FullScenario(t *testing.T) {
|
||||||
|
a := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "Web Access",
|
||||||
|
Description: "Allow web access",
|
||||||
|
Enabled: true,
|
||||||
|
SourcePostureChecks: []string{"pc2", "pc1"},
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "r1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "HTTP",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"g2", "g1"},
|
||||||
|
Destinations: []string{"g4", "g3"},
|
||||||
|
Ports: []string{"443", "80", "8080"},
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &Policy{
|
||||||
|
ID: "pol1",
|
||||||
|
AccountID: "acc1",
|
||||||
|
Name: "Web Access",
|
||||||
|
Description: "Allow web access",
|
||||||
|
Enabled: true,
|
||||||
|
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "r1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "HTTP",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"g1", "g2"},
|
||||||
|
Destinations: []string{"g3", "g4"},
|
||||||
|
Ports: []string{"80", "8080", "443"},
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
|||||||
}
|
}
|
||||||
return rule
|
return rule
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
|
||||||
|
if pm == nil || other == nil {
|
||||||
|
return pm == other
|
||||||
|
}
|
||||||
|
|
||||||
|
if pm.ID != other.ID ||
|
||||||
|
pm.PolicyID != other.PolicyID ||
|
||||||
|
pm.Name != other.Name ||
|
||||||
|
pm.Description != other.Description ||
|
||||||
|
pm.Enabled != other.Enabled ||
|
||||||
|
pm.Action != other.Action ||
|
||||||
|
pm.Bidirectional != other.Bidirectional ||
|
||||||
|
pm.Protocol != other.Protocol ||
|
||||||
|
pm.SourceResource != other.SourceResource ||
|
||||||
|
pm.DestinationResource != other.DestinationResource ||
|
||||||
|
pm.AuthorizedUser != other.AuthorizedUser {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSlicesEqualUnordered(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
sorted1 := make([]string, len(a))
|
||||||
|
sorted2 := make([]string, len(b))
|
||||||
|
copy(sorted1, a)
|
||||||
|
copy(sorted2, b)
|
||||||
|
slices.Sort(sorted1)
|
||||||
|
slices.Sort(sorted2)
|
||||||
|
return slices.Equal(sorted1, sorted2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
cmp := func(x, y RulePortRange) int {
|
||||||
|
if x.Start != y.Start {
|
||||||
|
if x.Start < y.Start {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if x.End != y.End {
|
||||||
|
if x.End < y.End {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
sorted1 := make([]RulePortRange, len(a))
|
||||||
|
sorted2 := make([]RulePortRange, len(b))
|
||||||
|
copy(sorted1, a)
|
||||||
|
copy(sorted2, b)
|
||||||
|
slices.SortFunc(sorted1, cmp)
|
||||||
|
slices.SortFunc(sorted2, cmp)
|
||||||
|
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
|
||||||
|
return x.Start == y.Start && x.End == y.End
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func authorizedGroupsEqual(a, b map[string][]string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, va := range a {
|
||||||
|
vb, ok := b[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSlicesEqualUnordered(va, vb) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
194
management/server/types/policyrule_test.go
Normal file
194
management/server/types/policyrule_test.go
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "80", "22"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"22", "443", "80"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "80"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{"443", "22"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g2", "g3"},
|
||||||
|
Destinations: []string{"g4", "g5"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g3", "g1", "g2"},
|
||||||
|
Destinations: []string{"g5", "g4"},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g2"},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Sources: []string{"g1", "g3"},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
{Start: 8000, End: 9000},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 80},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
PortRanges: []RulePortRange{
|
||||||
|
{Start: 80, End: 443},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u1", "u2", "u3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u3", "u1", "u2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g1": {"u1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
AuthorizedGroups: map[string][]string{
|
||||||
|
"g2": {"u1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.False(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
|
||||||
|
base := PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Name: "test",
|
||||||
|
Description: "desc",
|
||||||
|
Enabled: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
}
|
||||||
|
|
||||||
|
other := base
|
||||||
|
other.Name = "changed"
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Enabled = false
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Action = PolicyTrafficActionDrop
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
|
||||||
|
other = base
|
||||||
|
other.Protocol = PolicyRuleProtocolUDP
|
||||||
|
assert.False(t, base.Equal(&other))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_NilCases(t *testing.T) {
|
||||||
|
var a *PolicyRule
|
||||||
|
var b *PolicyRule
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
|
||||||
|
a = &PolicyRule{ID: "rule1"}
|
||||||
|
assert.False(t, a.Equal(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
|
||||||
|
a := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: []string{},
|
||||||
|
Sources: nil,
|
||||||
|
}
|
||||||
|
b := &PolicyRule{
|
||||||
|
ID: "rule1",
|
||||||
|
PolicyID: "pol1",
|
||||||
|
Ports: nil,
|
||||||
|
Sources: []string{},
|
||||||
|
}
|
||||||
|
assert.True(t, a.Equal(b))
|
||||||
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user