Compare commits

...

27 Commits

Author SHA1 Message Date
mlsmaycon
5eb28acb11 [management] Account-scoped ephemeral peer cleanup
Replace the per-peer linked list with a per-account map keyed by
accountID. Each entry holds only the latest disconnect timestamp we
have observed for that account and a single timer that fires the next
sweep. Sweeps query the database for the authoritative stale set,
batch the deletes through peers.Manager.DeletePeers, then drop the
account from the tracker when lastDisc + lifeTime <= now (else
re-arm at horizon + cleanupWindow).

The drop rule is the entire termination story: an account stays
tracked only while OnPeerDisconnected keeps refreshing the
timestamp. There is no internal feedback loop that can advance
lastDisc on its own, so once disconnects stop the account drops in
at most one sweep.

A timestamp beats the ref-counter alternative because the counter
drifts positive in three real situations the cleanup loop has no
signal for: peers deleted via the API while offline, peers that
reconnect within the lifetime window, and management restarts. The
timestamp design never claims to know the size of the stale set —
it only knows the latest disconnect we observed and uses that to
bound when it is safe to drop the account.

OnPeerConnected becomes a no-op. The sweep query already filters
reconnected peers at the database level (peer_status_connected =
false in the WHERE clause), so there is nothing the in-memory
tracker needs to do on reconnect. The interface method is preserved
for call-site compatibility.

LoadInitialPeers no longer runs the catch-up query synchronously.
It schedules a deferred load via time.AfterFunc at a random delay
between 8 and 10 minutes. Without the jitter, every management
replica in a fleet-wide deploy would issue the catch-up query
simultaneously. The catch-up itself is one GROUP BY against the
peers table:
```sql
  SELECT account_id, MAX(peer_status_last_seen)
  FROM peers
  WHERE ephemeral = true AND peer_status_connected = false
  GROUP BY account_id
```
For each row the tracker seeds an entry and arms a sweep at
max(now, last_seen + lifeTime) + cleanupWindow — so accounts whose
backlog is already stale get cleaned soon after the delay elapses,
and accounts that disconnected recently wait the remaining window.
OnPeerDisconnected calls that arrive during the delay window seed
the tracker live, and the catch-up query skips accounts that are
already tracked.

Stop() cancels both the deferred initial-load timer and every
per-account sweep timer, and flips a stopped flag so subsequent
OnPeerDisconnected calls are ignored. This makes restarts and test
teardown clean.

Two new store methods:
  GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
  GetEphemeralAccountsLastDisconnect(ctx)
Both are scoped, indexable queries that the existing peers table
supports without schema changes.

The pending metric is renamed from
management.ephemeral.peers.pending to
management.ephemeral.accounts.tracked to reflect the new semantics
(it now counts accounts on the cleanup list, not peers). Method
names on the metrics type are unchanged so no production call site
has to move. No new metric labels, no per-account cardinality.

The algorithm was validated against an in-memory SQLite peers
table through an 11-scenario prototype kept under proto/, including
pathological-churn and 4-hour randomized simulations. All scenarios
terminate; max observed per-account sweep rate stays bounded near
the lifeTime + cleanupWindow cadence even under sustained
disconnect churn.

Verification: go build, go vet, race-clean tests across the
ephemeral, store, and telemetry packages, plus a clean
golangci-lint pass on the touched packages.
2026-05-19 09:50:14 +02:00
Maycon Santos
af24fd7796 [management] Add metrics for peer status updates and ephemeral cleanup (#6196)
* [management] Add metrics for peer status updates and ephemeral cleanup

The session-fenced MarkPeerConnected / MarkPeerDisconnected path and
the ephemeral peer cleanup loop both run silently today: when fencing
rejects a stale stream, when a cleanup tick deletes peers, or when a
batch delete fails, we have no operational signal beyond log lines.

Add OpenTelemetry counters and a histogram so the same SLO-style
dashboards that already exist for the network-map controller can cover
peer connect/disconnect and ephemeral cleanup too.

All new attributes are bounded enums: operation in {connect,disconnect}
and outcome in {applied,stale,error,peer_not_found}. No account, peer,
or user ID is ever written as a metric label — total cardinality is
fixed at compile time (8 counter series, 2 histogram series, 4 unlabeled
ephemeral series).

Metric methods are nil-receiver safe so test composition that doesn't
wire telemetry (the bulk of the existing tests) works unchanged. The
ephemeral manager exposes a SetMetrics setter rather than taking the
collector through its constructor, keeping the constructor signature
stable across all test call sites.

* [management] Add OpenTelemetry metrics for ephemeral peer cleanup

Introduce counters for tracking ephemeral peer cleanup, including peers pending deletion, cleanup runs, successful deletions, and failed batches. Metrics are nil-receiver safe to ensure compatibility with test setups without telemetry.
2026-05-18 22:55:19 +02:00
Maycon Santos
13d32d274f [management] Fence peer status updates with a session token (#6193)
* [management] Fence peer status updates with a session token

The connect/disconnect path used a best-effort LastSeen-after-streamStart
comparison to decide whether a status update should land. Under contention
— a re-sync arriving while the previous stream's disconnect was still in
flight, or two management replicas seeing the same peer at once — the
check was a read-then-decide-then-write window: any UPDATE in between
caused the wrong row to be written. The Go-side time.Now() that fed the
comparison also drifted under lock contention, since it was captured
seconds before the write actually committed.

Replace it with an integer-nanosecond fencing token stored alongside the
status. Every gRPC sync stream uses its open time (UnixNano) as its token.
Connects only land when the incoming token is strictly greater than the
stored one; disconnects only land when the incoming token equals the
stored one (i.e. we're the stream that owns the current session). Both
are single optimistic-locked UPDATEs — no read-then-write, no transaction
wrapper.

LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The
caller never supplies it, so the value always reflects the real moment
of the UPDATE rather than the moment the caller queued the work — which
was already off by minutes under heavy lock contention.

Side effects (geo lookup, peer-login-expiration scheduling, network-map
fan-out) are explicitly documented as running after the fence UPDATE
commits, never inside it. Geo also skips the update when realIP equals
the stored ConnectionIP, dropping a redundant SavePeerLocation call on
same-IP reconnects.

Tests cover the three semantic cases (matched disconnect lands, stale
disconnect dropped, stale connect dropped) plus a 16-goroutine race test
that asserts the highest token always wins.

* [management] Add SessionStartedAt to peer status updates

Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety.

* Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields
2026-05-18 20:25:12 +02:00
Nicolas Frati
705f87fc20 [management] fix: device redirect uri wasn't registered (#6191)
* fix: device redirect uri wasn't registered

* fix lint
2026-05-18 12:57:59 +02:00
Viktor Liu
3f91f49277 Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176) 2026-05-16 16:52:57 +02:00
Maycon Santos
347c5bf317 Avoid context cancellation in cancelPeerRoutines (#6175)
When closing go routines and handling peer disconnect, we should avoid canceling the flow due to parent gRPC context cancellation.

This change triggers disconnection handling with a context that is not bound to the parent gRPC cancellation.
2026-05-16 16:29:01 +02:00
Viktor Liu
22e2519d71 [management] Avoid peer IP reallocation when account settings update preserves the network range (#6173) 2026-05-16 15:51:48 +02:00
Vlad
e916f12cca [proxy] auth token generation on mapping (#6157)
* [management / proxy] auth token generation on mapping

* fix tests
2026-05-15 19:13:44 +02:00
Viktor Liu
9ed2e2a5b4 [client] Drop DNS probes for passive health projection (#5971) 2026-05-15 17:07:38 +02:00
Viktor Liu
2ccae7ec47 [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) 2026-05-15 16:58:47 +02:00
Viktor Liu
07e5450117 [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141) 2026-05-14 16:42:40 +02:00
Viktor Liu
3f914090cb [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) 2026-05-14 16:22:53 +02:00
Viktor Liu
ea9fab4396 [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) 2026-05-14 16:05:33 +02:00
Vlad
77b479286e [management] fix offline statuses for public proxy clusters (#6133) 2026-05-14 13:27:50 +02:00
Maycon Santos
ab2a8794e7 [client] Add short flags for status command options (#6137)
* [client] Add short flags for status command options

* uppercase filters
2026-05-14 12:30:42 +02:00
Viktor Liu
9126a192ca [client] Set 0644 perms on SSH client config after os.CreateTemp (#6126) 2026-05-12 15:05:53 +02:00
Viktor Liu
1224d6e1ee [client] Persist management URL and pre-shared key overrides on login (#6065) 2026-05-12 14:52:56 +02:00
Nicolas Frati
96672dd1f8 [management] chores: update dex version (#6124)
* chores: update dex version

* chore: update dex fork
2026-05-12 13:50:35 +02:00
Viktor Liu
946ce4c3da [client] Fix --config flag default to point at profile path (#6122) 2026-05-11 17:48:21 +02:00
Vlad
07cbfdbede [proxy] feature: bring your own proxy (#5627) 2026-05-11 14:31:38 +02:00
Viktor Liu
a4114a5e45 [client] Skip DNS upstream failover on definitive EDE (#6089) 2026-05-11 10:00:23 +02:00
Viktor Liu
6b08e89c7b [relay] Preserve non-standard port in WS dialer URL prep (#6061) 2026-05-11 09:59:33 +02:00
Viktor Liu
a852b3bd34 [client, proxy] Harden uspfilter conntrack and share TCP relay (#5936) 2026-05-11 09:59:13 +02:00
Viktor Liu
afb83b3049 [client] Use unique temp file and clean up on failure when writing ssh config (#6064) 2026-05-11 09:58:49 +02:00
Nicolas Frati
e89aad09f5 [management] Enable MFA for local users (#5804)
* wip: totp for local users

* fix providers not getting populated

* polished UI and fix post_login_redirect_uri

* fix: make sure logout is only prompted from oidc flow

Signed-off-by: jnfrati <nicofrati@gmail.com>

* update templates

Signed-off-by: jnfrati <nicofrati@gmail.com>

* deps: update dex dependency

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fix qube issues

Signed-off-by: jnfrati <nicofrati@gmail.com>

* replace window with globalThis on home html

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fixed coderabbit comments

Signed-off-by: jnfrati <nicofrati@gmail.com>

* debug

* remove unused config and rename totp issuer

* deps: update dex reference to latest

* add dashboard post logout redirect uri to embedded config

* implemented api for mfa configuration

* update docs and config parsing

* catch error on idp manager init mfa

* fix tests

* Add remember me  for MFA

* Add cookie encryption and session share between tabs

* fixed logout showing non actionable error and session cookie encription key

* fixed missing mfa settings on sql query for account

* fix code index for mfa activity

---------

Signed-off-by: jnfrati <nicofrati@gmail.com>
Co-authored-by: braginini <bangvalo@gmail.com>
2026-05-08 16:31:20 +02:00
Maycon Santos
7da94a4956 [misc] Update CONTRIBUTING.md (#6076) 2026-05-07 16:16:48 +02:00
Pascal Fischer
39eac377e4 [management] add update reason to buffered calls (#6103) 2026-05-07 15:55:59 +02:00
139 changed files with 9999 additions and 2293 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -8,7 +8,7 @@ There are many ways that you can contribute:
- Sharing use cases in slack or Reddit
- Bug fix or feature enhancement
If you haven't already, join our slack workspace [here](https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A), we would love to discuss topics that need community contribution and enhancements to existing features.
If you haven't already, join our slack workspace [here](https://docs.netbird.io/slack-url), we would love to discuss topics that need community contribution and enhancements to existing features.
## Contents

View File

@@ -143,7 +143,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location")
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)

View File

@@ -43,16 +43,16 @@ func init() {
ipsFilterMap = make(map[string]struct{})
prefixNamesFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {

View File

@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {

View File

@@ -0,0 +1,125 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
)
func TestTCPCapEvicts(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "4")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 4, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 10; i++ {
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
}
require.LessOrEqual(t, len(tracker.connections), 4,
"TCP table must not exceed the configured cap")
require.Greater(t, len(tracker.connections), 0,
"some entries must remain after eviction")
// The most recently admitted flow must be present: eviction must make
// room for new entries, not silently drop them.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
"newest TCP flow must be admitted after eviction")
// A pre-cap flow must have been evicted to fit the last one.
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
"oldest TCP flow should have been evicted")
}
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "3")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
// Fill to cap with 3 live connections.
for i := 0; i < 3; i++ {
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
}
require.Len(t, tracker.connections, 3)
// Tombstone one by sending RST through IsValidInbound.
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
// Another live connection forces eviction. The tombstone must go first.
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
require.False(t, tombstonedStillPresent,
"tombstoned entry should be evicted before live entries")
require.LessOrEqual(t, len(tracker.connections), 3)
// Both live pre-cap entries must survive: eviction must prefer the
// tombstone, not just satisfy the size bound by dropping any entry.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
}
func TestUDPCapEvicts(t *testing.T) {
t.Setenv(EnvUDPMaxEntries, "5")
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 5, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 12; i++ {
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
}
require.LessOrEqual(t, len(tracker.connections), 5)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
"newest UDP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
"oldest UDP flow should have been evicted")
}
func TestICMPCapEvicts(t *testing.T) {
t.Setenv(EnvICMPMaxEntries, "3")
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 3, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
for i := 0; i < 8; i++ {
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
}
require.LessOrEqual(t, len(tracker.connections), 3)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
"newest ICMP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
"oldest ICMP flow should have been evicted")
}

View File

@@ -3,15 +3,61 @@ package conntrack
import (
"net"
"net/netip"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// evictSampleSize bounds how many map entries we scan per eviction call.
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
// heuristic is good enough for a conntrack table that only overflows under
// abuse.
const evictSampleSize = 8
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
// def on empty or invalid; logs a warning on invalid.
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
v := os.Getenv(name)
if v == "" {
return def
}
d, err := time.ParseDuration(v)
if err != nil {
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
}
if d <= 0 {
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return d
}
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
// invalid, or non-positive. Logs a warning on invalid input.
func envInt(logger *nblog.Logger, name string, def int) int {
v := os.Getenv(name)
if v == "" {
return def
}
n, err := strconv.Atoi(v)
switch {
case err != nil:
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
case n <= 0:
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return n
}
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID

View File

@@ -0,0 +1,11 @@
//go:build !ios && !android
package conntrack
// Default per-tracker entry caps on desktop/server platforms. These mirror
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
const (
DefaultMaxTCPEntries = 65536
DefaultMaxUDPEntries = 16384
DefaultMaxICMPEntries = 2048
)

View File

@@ -0,0 +1,13 @@
//go:build ios || android
package conntrack
// Default per-tracker entry caps on mobile platforms. iOS network extensions
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
// is ~200 B plus map overhead).
const (
DefaultMaxTCPEntries = 4096
DefaultMaxUDPEntries = 2048
DefaultMaxICMPEntries = 512
)

View File

@@ -50,6 +50,9 @@ type ICMPConnTrack struct {
ICMPCode uint8
}
// EnvICMPMaxEntries caps the ICMP conntrack table size.
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
@@ -58,6 +61,7 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -171,6 +175,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
flowLogger: flowLogger,
}
@@ -257,7 +262,9 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -276,10 +283,15 @@ func (t *ICMPTracker) track(
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
@@ -323,6 +335,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
}
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
func (t *ICMPTracker) evictOneLocked() {
var candKey ICMPConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
@@ -331,8 +371,10 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -38,6 +38,27 @@ const (
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
FinWaitTimeout = 60 * time.Second
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
// holding CloseWait longer than this should bump the env var.
CloseWaitTimeout = 60 * time.Second
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
LastAckTimeout = 30 * time.Second
)
// Env vars to override per-state teardown timeouts. Values parsed by
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
// defaults above with a warning.
const (
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
// (tombstones first) are evicted when the cap is reached.
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
)
// TCPState represents the state of a TCP connection
@@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() {
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
finWaitTimeout time.Duration
closeWaitTimeout time.Duration
lastAckTimeout time.Duration
maxEntries int
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
@@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger,
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
@@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
if exists || flags&TCPSyn == 0 {
return
}
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
// create spurious conntrack entries. Not mandated by RFC 9293 but a
// common hardening (Linux netfilter/nftables rejects these too).
if !isValidFlagCombination(flags) {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
@@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
// iteration order is randomized), preferring a tombstone. If no tombstone
// found in the sample, evicts the oldest among the sampled entries. O(1)
// worst case — cheap enough to run on every insert at cap during abuse.
func (t *TCPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
if c.IsTombstone() {
delete(t.connections, k)
return
}
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
// TypeEnd is already emitted at the state transition to
// TimeWait and when a connection is tombstoned. Only emit
// here when we're reaping a still-active flow.
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
}
delete(t.connections, candKey)
}
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
@@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return false
}
// Reject illegal flag combinations regardless of state. These never belong
// to a legitimate flow and must not advance or tear down state.
if !isValidFlagCombination(flags) {
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
}
return false
}
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
}
return false
}
@@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return true
}
// updateState updates the TCP connection state based on flags
// updateState updates the TCP connection state based on flags.
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
// Malformed flag combinations must not refresh lastSeen or drive state,
// otherwise spoofed packets keep a dead flow alive past its timeout.
if !isValidFlagCombination(flags) {
return
}
conn.UpdateLastSeen()
currentState := conn.GetState()
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
// cannot apply the RFC 5961 in-window RST check, so we conservatively
// reject RSTs that the spec would accept (TIME-WAIT with in-window
// SEQ, SynSent from same direction as own SYN, etc.).
t.handleRst(key, conn, currentState, packetDir)
return
}
var newState TCPState
switch currentState {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
}
newState := nextState(currentState, conn.Direction, packetDir, flags)
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
return
}
t.onTransition(key, conn, currentState, newState, packetDir)
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if packetDir != conn.Direction {
newState = TCPStateEstablished
} else {
// Simultaneous open
newState = TCPStateSynReceived
}
}
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
// from the SYN direction are ignored; otherwise the flow is tombstoned.
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
// TimeWait exists to absorb late segments; don't let a late RST
// tombstone the entry and break same-4-tuple reuse.
if currentState == TCPStateTimeWait {
return
}
// A RST from the same direction as the SYN cannot be a legitimate
// response and must not tear down a half-open connection.
if currentState == TCPStateSynSent && packetDir == conn.Direction {
return
}
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
return
}
conn.SetTombstone()
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
if packetDir == conn.Direction {
newState = TCPStateEstablished
}
}
// stateTransition describes one state's transition logic. It receives the
// packet's flags plus whether the packet direction matches the connection's
// origin direction (same=true means same side as the SYN initiator). Return 0
// for no transition.
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
case TCPStateEstablished:
if flags&TCPFin != 0 {
if packetDir == conn.Direction {
newState = TCPStateFinWait1
} else {
newState = TCPStateCloseWait
}
}
// stateTable maps each state to its transition function. Centralized here so
// nextState stays trivial and each rule is easy to read in isolation.
var stateTable = map[TCPState]stateTransition{
TCPStateNew: transNew,
TCPStateSynSent: transSynSent,
TCPStateSynReceived: transSynReceived,
TCPStateEstablished: transEstablished,
TCPStateFinWait1: transFinWait1,
TCPStateFinWait2: transFinWait2,
TCPStateClosing: transClosing,
TCPStateCloseWait: transCloseWait,
TCPStateLastAck: transLastAck,
}
case TCPStateFinWait1:
if packetDir != conn.Direction {
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
newState = TCPStateClosing
case flags&TCPFin != 0:
newState = TCPStateClosing
case flags&TCPAck != 0:
newState = TCPStateFinWait2
}
}
// nextState returns the target TCP state for the given current state and
// packet, or 0 if the packet does not trigger a transition.
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
fn, ok := stateTable[currentState]
if !ok {
return 0
}
return fn(flags, connDir, packetDir == connDir)
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
newState = TCPStateTimeWait
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if connDir == nftypes.Egress {
return TCPStateSynSent
}
return TCPStateSynReceived
}
return 0
}
case TCPStateClosing:
if flags&TCPAck != 0 {
newState = TCPStateTimeWait
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if same {
return TCPStateSynReceived // simultaneous open
}
return TCPStateEstablished
}
return 0
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
newState = TCPStateLastAck
}
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
return TCPStateEstablished
}
return 0
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
newState = TCPStateClosed
}
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin == 0 {
return 0
}
if same {
return TCPStateFinWait1
}
return TCPStateCloseWait
}
// transFinWait1 handles the active-close peer response. A FIN carrying our
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
if same {
return 0
}
if flags&TCPFin != 0 && flags&TCPAck != 0 {
return TCPStateTimeWait
}
switch {
case flags&TCPFin != 0:
return TCPStateClosing
case flags&TCPAck != 0:
return TCPStateFinWait2
}
return 0
}
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transClosing completes a simultaneous close on the peer's ACK.
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && same {
return TCPStateLastAck
}
return 0
}
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateClosed
}
return 0
}
// onTransition handles logging and flow-event emission after a successful
// state transition. TimeWait and Closed are terminal for flow accounting.
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
traceOn := t.logger.Enabled(nblog.LevelTrace)
if traceOn {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
}
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
switch to {
case TCPStateTimeWait:
if traceOn {
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
if traceOn {
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
// isValidStateForFlags checks if the TCP flags are valid for the current
// connection state. Caller must have already verified the flag combination is
// legal via isValidFlagCombination.
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
@@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() {
timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
timeout = t.finWaitTimeout
case TCPStateCloseWait:
timeout = t.closeWaitTimeout
case TCPStateLastAck:
timeout = t.lastAckTimeout
default:
// SynSent / SynReceived / New
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
delete(t.connections, key)
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
// event already handled by state change
if currentState != TCPStateTimeWait {

View File

@@ -0,0 +1,100 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
// RST hygiene tests: the tracker currently closes the flow on any RST that
// matches the 4-tuple, regardless of direction or state. These tests cover
// the minimum checks we want (no SEQ tracking).
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateSynSent, conn.GetState())
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
// cannot be a legitimate response. It must not close the connection.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
require.Equal(t, TCPStateSynSent, conn.GetState(),
"RST in same direction as SYN must not close connection")
require.False(t, conn.IsTombstone())
}
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Drive to TIME-WAIT via active close.
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
// exists to absorb late segments).
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState(),
"RST in TIME-WAIT must not transition state")
require.False(t, conn.IsTombstone(),
"RST in TIME-WAIT must not tombstone the entry")
}
func TestTCPIllegalFlagCombos(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
// Illegal combos must be rejected and must not change state.
combos := []struct {
name string
flags uint8
}{
{"SYN+RST", TCPSyn | TCPRst},
{"FIN+RST", TCPFin | TCPRst},
{"SYN+FIN", TCPSyn | TCPFin},
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
}
for _, c := range combos {
t.Run(c.name, func(t *testing.T) {
before := conn.GetState()
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
require.Equal(t, before, conn.GetState(),
"illegal flag combo must not change state")
require.False(t, conn.IsTombstone())
})
}
}

View File

@@ -0,0 +1,235 @@
package conntrack
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// These tests exercise cases where the TCP state machine currently advances
// on retransmitted or wrong-direction segments and tears the flow down
// prematurely. They are expected to fail until the direction checks are added.
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer sends FIN -> CloseWait (our app has not yet closed).
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
// sent our FIN yet, so state must remain CloseWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "retransmitted peer FIN must still be accepted")
require.Equal(t, TCPStateCloseWait, conn.GetState(),
"retransmitted peer FIN must not advance CloseWait to LastAck")
// Our app finally closes -> LastAck.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Peer ACK closes.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// We initiate close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Stray retransmit of our own FIN (same direction as originator) must
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait2, conn.GetState(),
"own FIN retransmit must not advance FinWait2 to TimeWait")
// Peer FIN -> TimeWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
}
func TestTCPLastAckDirectionCheck(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateLastAck, conn.GetState())
// Our own ACK retransmit (same direction as originator) must NOT close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState(),
"own ACK retransmit in LastAck must not transition to Closed")
// Peer's ACK -> Closed.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Our own ACK retransmit (same direction as originator) must not advance.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState(),
"own ACK in FinWait1 must not advance to FinWait2")
}
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
// Verify cleanup reaps entries in each teardown state at the configured
// per-state timeout, not at the single handshake timeout.
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
t.Setenv(EnvTCPLastAckTimeout, "30ms")
dstIP := netip.MustParseAddr("100.64.0.2")
dstPort := uint16(80)
// Drives a connection to the target state, forces its lastSeen well
// beyond the configured timeout, runs cleanup, and asserts reaping.
cases := []struct {
name string
// drive takes a fresh tracker and returns the conn key after
// transitioning the flow into the intended teardown state.
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
}{
{
name: "FinWait1",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
},
},
{
name: "FinWait2",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
},
},
{
name: "CloseWait",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
},
},
{
name: "LastAck",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
},
},
}
// Use a unique source port per subtest so nothing aliases.
port := uint16(12345)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
srcIP := netip.MustParseAddr("100.64.0.1")
port++
key, wantState := c.drive(t, tracker, srcIP, port)
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, wantState, conn.GetState())
// Age the entry past the largest per-state timeout.
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
tracker.cleanup()
_, exists := tracker.connections[key]
require.False(t, exists, "%s entry should be reaped", c.name)
})
}
}
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
// teardown states, which some stacks emit during close.
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer FIN -> CloseWait.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
// Peer pushes trailing data + FIN|PSH|ACK (legal).
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
"FIN|PSH|ACK in CloseWait must be accepted")
// Bare ACK keepalive from peer in CloseWait must be accepted.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
"bare ACK in CloseWait must be accepted")
}

View File

@@ -17,6 +17,9 @@ const (
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
// EnvUDPMaxEntries caps the UDP conntrack table size.
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
)
// UDPConnTrack represents a UDP connection state
@@ -34,6 +37,7 @@ type UDPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
flowLogger: flowLogger,
}
@@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
@@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return true
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample: picks the oldest among up to evictSampleSize entries.
func (t *UDPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
@@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -787,7 +787,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
}
return false
}
@@ -901,7 +903,9 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
}
return true
}
@@ -1044,11 +1048,13 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder
if fragment {
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
if m.logger.Enabled(nblog.LevelTrace) {
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
}
return false
}
@@ -1091,8 +1097,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1142,8 +1150,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
}
return true
}
@@ -1160,8 +1170,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1287,7 +1299,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.decodePacket(packetData); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace1("couldn't decode packet, err: %s", err)
}
return false, false
}

View File

@@ -13,6 +13,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -97,8 +98,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return conn, nil
}
@@ -121,12 +124,14 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
txBytes := f.handleEchoResponse(conn, id, v6)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
proto := "ICMP"
if v6 {
proto = "ICMPv6"
if f.logger.Enabled(nblog.LevelTrace) {
proto := "ICMP"
if v6 {
proto = "ICMPv6"
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
@@ -224,13 +229,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
}
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}

View File

@@ -1,11 +1,8 @@
package forwarder
import (
"context"
"io"
"net"
"strconv"
"sync"
"github.com/google/uuid"
@@ -15,7 +12,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/util/netrelay"
)
// handleTCP is called by the TCP forwarder for new connections.
@@ -37,7 +36,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
}
return
}
@@ -60,64 +61,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
}
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
// netrelay.Relay copies bidirectionally with proper half-close propagation
// and fully closes both conns before returning.
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
Logger: f.logger,
})
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
// Close the netstack endpoint after both conns are drained.
ep.Close()
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
@@ -126,7 +85,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
}
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}

View File

@@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
}
@@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
}
return true
}
@@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.Unlock()
success = true
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
}
go f.proxyUDP(connCtx, pConn, id, ep)
return true
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
}
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)

View File

@@ -53,16 +53,17 @@ var levelStrings = map[Level]string{
}
type logMessage struct {
level Level
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
level Level
argCount uint8
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
}
// Logger is a high-performance, non-blocking logger
@@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
// Enabled reports whether the given level is currently logged. Callers on the
// hot path should guard log sites with this to avoid boxing arguments into
// any when the level is off.
func (l *Logger) Enabled(level Level) bool {
return l.level.Load() >= uint32(level)
}
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
select {
@@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) {
func (l *Logger) Error1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) {
func (l *Logger) Error2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Warn2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) {
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
func (l *Logger) Debug2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
// to avoid allocating an args slice on the fast path when the arg count is
// known (0-3). Args beyond 3 land on the general variadic path; callers on
// the hot path should prefer DebugN for known counts.
func (l *Logger) Debugf(format string, args ...any) {
if l.level.Load() < uint32(LevelDebug) {
return
}
switch len(args) {
case 0:
l.Debug(format)
case 1:
l.Debug1(format, args[0])
case 2:
l.Debug2(format, args[0], args[1])
case 3:
l.Debug3(format, args[0], args[1], args[2])
default:
l.sendVariadic(LevelDebug, format, args)
}
}
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
// beyond the 8-arg slot limit are dropped so callers don't produce silently
// empty log lines via uint8 wraparound in argCount.
func (l *Logger) sendVariadic(level Level, format string, args []any) {
const maxArgs = 8
n := len(args)
if n > maxArgs {
n = maxArgs
}
msg := logMessage{level: level, argCount: uint8(n), format: format}
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
for i := 0; i < n; i++ {
*slots[i] = args[i]
}
select {
case l.msgChannel <- msg:
default:
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
func (l *Logger) Trace2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
default:
}
}
@@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
default:
}
}
@@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
@@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = append(*buf, levelStrings[msg.level]...)
*buf = append(*buf, ' ')
// Count non-nil arguments for switch
argCount := 0
if msg.arg1 != nil {
argCount++
if msg.arg2 != nil {
argCount++
if msg.arg3 != nil {
argCount++
if msg.arg4 != nil {
argCount++
if msg.arg5 != nil {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
}
}
}
var formatted string
switch argCount {
switch msg.argCount {
case 0:
formatted = msg.format
case 1:

View File

@@ -11,6 +11,7 @@ import (
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
var (
@@ -262,11 +263,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet destination: %v", err)
}
return false
}
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
}
return true
}
@@ -283,11 +288,15 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet source: %v", err)
}
return false
}
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
}
return true
}
@@ -612,7 +621,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite port: %v", err)
}
return false
}
d.dnatOrigPort = rule.origPort

View File

@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
; Remove autostart entries from every view a previous installer may have used.
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")

View File

@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization.
Firewall Rules (Linux only)
The bundle includes two separate firewall rule files:
The bundle includes the following firewall-related files:
iptables.txt:
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
- Includes all tables (filter, nat, mangle, raw, security)
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
ip6tables.txt:
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
- Same table coverage and anonymization as iptables.txt
- Omitted when ip6tables is not installed or no IPv6 rules are present
ipset.txt:
- Output of 'ipset list' (family-agnostic)
- IP addresses are anonymized; set names and types remain unchanged
nftables.txt:
- Complete nftables ruleset obtained via 'nft -a list ruleset'
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
- Includes rule handle numbers and packet counters
- All tables, chains, and rules are included
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
- All IP addresses are anonymized; chain/table names remain unchanged
sysctls.txt:
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
- Interface names anonymized when --anonymize is set
IP Rules (Linux only)
The ip_rules.txt file contains detailed IP routing rule information:
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addSysctls(); err != nil {
log.Errorf("failed to add sysctls to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}

View File

@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
// addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
iptablesRules, err := collectIPTablesRules()
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err)
log.Warnf("Failed to collect ipset information: %v", err)
} else {
if g.anonymize {
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
}
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err)
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
log.Warnf("Failed to add ipset output to bundle: %v", err)
}
}
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
return nil
}
// collectIPTablesRules collects rules using both iptables-save and verbose listing
func collectIPTablesRules() (string, error) {
var builder strings.Builder
saveOutput, err := collectIPTablesSave()
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
rules, err := collectIPTablesRules(saveBin, listBin)
if err != nil {
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
} else {
builder.WriteString("=== iptables-save output ===\n")
log.Warnf("Failed to collect %s rules: %v", listBin, err)
return
}
if g.anonymize {
rules = g.anonymizer.AnonymizeString(rules)
}
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
}
}
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
// Returns an error when neither command produced any output (e.g. the binary is missing),
// so the caller can skip writing an empty file.
func collectIPTablesRules(saveBin, listBin string) (string, error) {
var builder strings.Builder
var collected bool
var firstErr error
saveOutput, err := runCommand(saveBin)
switch {
case err != nil:
firstErr = err
log.Warnf("Failed to collect %s output: %v", saveBin, err)
case strings.TrimSpace(saveOutput) == "":
log.Debugf("%s produced no output, skipping", saveBin)
default:
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
builder.WriteString(saveOutput)
builder.WriteString("\n")
collected = true
}
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n")
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
builder.WriteString(listHeader)
tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table))
stats, err := getTableStatistics(table)
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err)
if firstErr == nil {
firstErr = err
}
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
continue
}
builder.WriteString(fmt.Sprintf("*%s\n", table))
builder.WriteString(stats)
builder.WriteString("\n")
collected = true
}
if !collected {
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
}
return builder.String(), nil
}
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save")
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
func runCommand(name string, args ...string) (string, error) {
cmd := exec.Command(name, args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
}
rules := stdout.String()
if strings.TrimSpace(rules) == "" {
return "", fmt.Errorf("no iptables rules found")
}
return rules, nil
}
// getTableStatistics gets verbose statistics for an entire table using iptables command
func getTableStatistics(table string) (string, error) {
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
}
return stdout.String(), nil
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
return fmt.Sprintf("type-%v", keyType)
}
}
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
func (g *BundleGenerator) addSysctls() error {
log.Info("Collecting sysctls")
content := collectSysctls()
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
return fmt.Errorf("add sysctls to bundle: %w", err)
}
return nil
}
// collectSysctls reads every sysctl that the netbird client may modify, plus
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
func collectSysctls() string {
var builder strings.Builder
writeSysctlGroup(&builder, "forwarding", []string{
"net.ipv4.ip_forward",
"net.ipv6.conf.all.forwarding",
"net.ipv6.conf.default.forwarding",
})
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
writeSysctlGroup(&builder, "rp_filter", append(
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
listInterfaceSysctls("ipv4", "rp_filter")...,
))
writeSysctlGroup(&builder, "src_valid_mark", append(
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
listInterfaceSysctls("ipv4", "src_valid_mark")...,
))
writeSysctlGroup(&builder, "conntrack", []string{
"net.netfilter.nf_conntrack_acct",
"net.netfilter.nf_conntrack_tcp_loose",
})
writeSysctlGroup(&builder, "tcp", []string{
"net.ipv4.tcp_tw_reuse",
})
return builder.String()
}
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
for _, key := range keys {
value, err := readSysctl(key)
if err != nil {
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
continue
}
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
}
builder.WriteString("\n")
}
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
// (callers add those explicitly so they appear first).
func listInterfaceSysctls(family, leaf string) []string {
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
var keys []string
for _, e := range entries {
name := e.Name()
if name == "all" || name == "default" {
continue
}
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
}
sort.Strings(keys)
return keys
}
func readSysctl(key string) (string, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
value, err := os.ReadFile(path)
if err != nil {
return "", err
}
return strings.TrimSpace(string(value)), nil
}

View File

@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
// IP rules are only supported on Linux
return nil
}
func (g *BundleGenerator) addSysctls() error {
// Sysctl collection is only supported on Linux
return nil
}

View File

@@ -16,6 +16,10 @@ type hostManager interface {
restoreHostDNS() error
supportCustomPort() bool
string() string
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
// hardcoded Quad9 on iOS, nil for noop / mock.
getOriginalNameservers() []netip.Addr
}
type SystemDNSSettings struct {
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
func (n noopHostConfigurator) string() string {
return "noop"
}
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}

View File

@@ -1,14 +1,20 @@
package dns
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// androidHostManager is a noop on the OS side (Android's VPN service handles
// DNS for us) but tracks the OS-reported resolver list pushed via
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
type androidHostManager struct {
holder *hostsDNSHolder
}
func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
return &androidHostManager{holder: holder}, nil
}
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
func (a androidHostManager) string() string {
return "none"
}
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
hosts := a.holder.get()
out := make([]netip.Addr, 0, len(hosts))
for ap := range hosts {
out = append(out, ap.Addr())
}
return out
}

View File

@@ -3,6 +3,7 @@ package dns
import (
"encoding/json"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
}, nil
}
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
return []netip.Addr{
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
}
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strings"
"syscall"
"time"
@@ -44,9 +45,11 @@ const (
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
@@ -67,10 +70,11 @@ const (
)
type registryConfigurator struct {
guid string
routingAll bool
gpo bool
nrptEntryCount int
guid string
routingAll bool
gpo bool
nrptEntryCount int
origNameservers []netip.Addr
}
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
gpo: useGPO,
}
origNameservers, err := configurator.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
}
configurator.origNameservers = origNameservers
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
return configurator, nil
}
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
// registry key except the WG adapter. v4 and v6 servers live in separate
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
var merr *multierror.Error
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
addrs, err := r.captureFromTcpipRoot(root)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
continue
}
for _, addr := range addrs {
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
if err != nil {
return nil, fmt.Errorf("open key: %w", err)
}
defer closer(root)
guids, err := root.ReadSubKeyNames(-1)
if err != nil {
return nil, fmt.Errorf("read subkeys: %w", err)
}
var out []netip.Addr
for _, guid := range guids {
if strings.EqualFold(guid, r.guid) {
continue
}
out = append(out, readInterfaceNameservers(rootPath, guid)...)
}
return out, nil
}
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
keyPath := rootPath + "\\" + guid
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return nil
}
defer closer(k)
// Static NameServer wins over DhcpNameServer for actual resolution.
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
raw, _, err := k.GetStringValue(name)
if err != nil || raw == "" {
continue
}
if out := parseRegistryNameservers(raw); len(out) > 0 {
return out
}
}
return nil
}
func parseRegistryNameservers(raw string) []netip.Addr {
var out []netip.Addr
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
addr, err := netip.ParseAddr(strings.TrimSpace(field))
if err != nil {
continue
}
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// Drop unzoned link-local: not routable without a scope id. If
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
continue
}
out = append(out, addr)
}
return out
}
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(r.origNameservers)
}
func (r *registryConfigurator) supportCustomPort() bool {
return false
}

View File

@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Unlock()
}
//nolint:unused
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList

View File

@@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability(context.Context) {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{

View File

@@ -9,6 +9,7 @@ import (
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
}
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// SetRouteSources mock implementation of SetRouteSources from Server interface
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
// Mock implementation - no-op
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"time"
@@ -32,6 +33,15 @@ const (
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
networkManagerDbusIPv4Key = "ipv4"
networkManagerDbusIPv6Key = "ipv6"
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
}
type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
origNameservers []netip.Addr
}
// the types below are based on dbus specification, each field is mapped to a dbus type
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
return &networkManagerDbusConfigurator{
c := &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from NetworkManager: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads DNS servers from every NM device's
// IP4Config / IP6Config except our WG device.
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
devices, err := networkManagerListDevices()
if err != nil {
return nil, fmt.Errorf("list devices: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, dev := range devices {
if dev == n.dbusLinkObject {
continue
}
ifaceName := readNetworkManagerDeviceInterface(dev)
for _, addr := range readNetworkManagerDeviceDNS(dev) {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// IP6Config.Nameservers is a byte slice without zone info;
// reattach the device's interface name so a captured fe80::…
// stays routable.
if addr.IsLinkLocalUnicast() && ifaceName != "" {
addr = addr.WithZone(ifaceName)
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return ""
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
if err != nil {
return ""
}
s, _ := v.Value().(string)
return s
}
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
}
defer closeConn()
var devs []dbus.ObjectPath
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
return nil, err
}
return devs, nil
}
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return nil
}
defer closeConn()
var out []netip.Addr
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
out = append(out, readIPv4ConfigDNS(path)...)
}
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
out = append(out, readIPv6ConfigDNS(path)...)
}
return out
}
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
v, err := obj.GetProperty(property)
if err != nil {
return ""
}
path, ok := v.Value().(dbus.ObjectPath)
if !ok || path == "/" {
return ""
}
return path
}
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
// legacy uint32 Nameservers property.
if out := readIPv4NameserverData(obj); len(out) > 0 {
return out
}
return readIPv4LegacyNameservers(obj)
}
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([]map[string]dbus.Variant)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
addrVar, ok := entry["address"]
if !ok {
continue
}
s, ok := addrVar.Value().(string)
if !ok {
continue
}
if a, err := netip.ParseAddr(s); err == nil {
out = append(out, a)
}
}
return out
}
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([]uint32)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, n := range raw {
var b [4]byte
binary.LittleEndian.PutUint32(b[:], n)
out = append(out, netip.AddrFrom4(b))
}
return out
}
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([][]byte)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, b := range raw {
if a, ok := netip.AddrFromSlice(b); ok {
out = append(out, a)
}
}
return out
}
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(n.origNameservers)
}
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager()
return newHostManager(s.hostsDNSHolder)
}

View File

@@ -6,7 +6,7 @@ import (
"net"
"net/netip"
"os"
"strings"
"runtime"
"testing"
"time"
@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -31,8 +32,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -101,16 +104,17 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
return &upstreamResolverBase{
domain: domain,
upstreamServers: srvs,
cancel: func() {},
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
}
}
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
{false, "domain1", false},
{false, "domain2", false},
},
},
statusRecorder: peer.NewRecorder("mgm"),
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{}
for _, item := range config.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
domainsUpdate = strings.Join(domains, ",")
return nil
}
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"},
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
},
}, nil, 0)
deactivate(nil)
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got := strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got)
}
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got = strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
}
}
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
func TestDNSPermanent_matchOnly(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
}
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
// + androidHostManager). On non-android the desktop host manager replaces it
// during Initialize and the assertion stops making sense. Skipped here until we
// have an android CI runner.
func skipUnlessAndroid(t *testing.T) {
t.Helper()
if runtime.GOOS != "android" {
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
@@ -1065,7 +1017,6 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
// admin-defined nameserver groups targeting the same domain collapse into a
// single handler with each group preserved as a sequential inner list.
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
}
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
require.NoError(t, err)
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
assert.Equal(t, "example.com", muxUpdates[0].domain)
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
handler := muxUpdates[0].handler.(*upstreamResolver)
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
assert.Equal(t, upstreamRace{
netip.MustParseAddrPort("192.0.2.2:53"),
netip.MustParseAddrPort("192.0.2.3:53"),
}, handler.upstreamServers[1])
}
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
// (overlay route selected-but-no-active-peer) is intentionally NOT an
// input to the evaluator anymore: the verdict drives the Enabled flag,
// which must always reflect what we actually observed. Gate-aware event
// suppression is tested separately in the projection test.
//
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
// fresh-working → Unhealthy; otherwise Undecided.
func TestEvaluateNSGroupHealth(t *testing.T) {
now := time.Now()
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
okThenFail := UpstreamHealth{
LastOk: now.Add(-10 * time.Second),
LastFail: now.Add(-1 * time.Second),
LastErr: "timeout",
}
failThenOk := UpstreamHealth{
LastOk: now.Add(-1 * time.Second),
LastFail: now.Add(-10 * time.Second),
LastErr: "timeout",
}
tests := []struct {
name string
health map[netip.AddrPort]UpstreamHealth
servers []netip.AddrPort
wantVerdict nsGroupVerdict
wantErrSubst string
}{
{
name: "no record, undecided",
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "fresh success, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "fresh failure, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "only stale success, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "only stale failure, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "both fresh, fail newer, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "both fresh, ok newer, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one success wins",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
b: recentOk,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one fail one unseen, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "two upstreams, all recent failures, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "SERVFAIL",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
if tc.wantErrSubst != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErrSubst)
} else {
assert.NoError(t, err)
}
})
}
}
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
health map[netip.AddrPort]UpstreamHealth
}
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (h *healthStubHandler) Stop() {}
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
return h.health
}
// TestProjection_SteadyStateIsSilent guards against duplicate events:
// while a group stays Unhealthy tick after tick, only the first
// Unhealthy transition may emit. Same for staying Healthy.
func TestProjection_SteadyStateIsSilent(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "first fail emits warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.tick()
fx.expectNoEvent("staying unhealthy must not re-emit")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery on transition")
fx.tick()
fx.tick()
fx.expectNoEvent("staying healthy must not re-emit")
}
// projTestFixture is the common setup for the projection tests: a
// single-upstream group whose route classification the test can flip by
// assigning to selected/active. Callers drive failures/successes by
// mutating stub.health and calling refreshHealth.
type projTestFixture struct {
t *testing.T
recorder *peer.Status
events <-chan *proto.SystemEvent
server *DefaultServer
stub *healthStubHandler
group *nbdns.NameServerGroup
srv netip.AddrPort
selected route.HAMap
active route.HAMap
}
func newProjTestFixture(t *testing.T) *projTestFixture {
t.Helper()
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
srv := netip.MustParseAddrPort("100.64.0.1:53")
fx := &projTestFixture{
t: t,
recorder: recorder,
events: sub.Events(),
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
srv: srv,
group: &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
},
}
fx.server = &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
fx.server.mux.Unlock()
return fx
}
func (f *projTestFixture) setHealth(h UpstreamHealth) {
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
}
func (f *projTestFixture) tick() []peer.NSGroupState {
f.server.refreshHealth()
return f.recorder.GetDNSStates()
}
func (f *projTestFixture) expectNoEvent(why string) {
f.t.Helper()
select {
case evt := <-f.events:
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
case <-time.After(100 * time.Millisecond):
}
}
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
f.t.Helper()
select {
case evt := <-f.events:
assert.Contains(f.t, evt.Message, substr, why)
return evt
case <-time.After(time.Second):
f.t.Fatalf("expected event (%s) with %q", why, substr)
return nil
}
}
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
// that is not inside any selected route (public DNS) fires the warning
// on the first Unhealthy tick, no grace period.
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "public DNS failure")
}
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
// the upstream is inside a selected route AND the route has a Connected
// peer. Tunnel is up, failure is real, emit immediately.
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "overlay + connected failure")
}
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
// First tick: Unhealthy display, no warning. After the grace window
// elapses with no recovery, the warning fires.
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
grace := 50 * time.Millisecond
fx := newProjTestFixture(t)
fx.server.warningDelayBase = grace
fx.selected = overlayMapForTest
// active stays nil: routed but not connected.
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
fx.expectNoEvent("first fail tick within grace window")
time.Sleep(grace + 10*time.Millisecond)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "warning after grace window")
}
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
// whose address is inside the WireGuard overlay range but is not
// covered by any selected route (peer-to-peer DNS without an explicit
// route). Until a peer reports Connected for that address, startup
// failures must be held just like the routed case.
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-sub.Events():
t.Fatalf("unexpected event during grace window: %+v", evt)
case <-time.After(100 * time.Millisecond):
}
time.Sleep(60 * time.Millisecond)
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
server.refreshHealth()
select {
case evt := <-sub.Events():
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected warning after grace window")
}
}
// TestProjection_StopClearsHealthState verifies that Stop wipes the
// per-group projection state so a subsequent Start doesn't inherit
// sticky flags (notably everHealthy) that would bypass the grace
// window during the next peer handshake.
func TestProjection_StopClearsHealthState(t *testing.T) {
wgIface := &mocWGIface{}
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgIface,
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: defaultWarningDelayBase,
currentConfigHash: ^uint64(0),
}
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
srv := netip.MustParseAddrPort("8.8.8.8:53")
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
server.healthProjectMu.Lock()
p, ok := server.nsGroupProj[generateGroupKey(group)]
server.healthProjectMu.Unlock()
require.True(t, ok, "projection state should exist after tick")
require.True(t, p.everHealthy, "tick with success must set everHealthy")
server.Stop()
server.healthProjectMu.Lock()
cleared := server.nsGroupProj == nil
server.healthProjectMu.Unlock()
assert.True(t, cleared, "Stop must clear nsGroupProj")
}
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
fx.selected = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectNoEvent("fail within grace, warning suppressed")
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("recovery without prior warning must not emit")
}
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
// whole design leans on: recovery events only appear when a warning
// event was actually emitted for the current streak. A Healthy verdict
// without a prior warning is silent, so the user never sees "recovered"
// out of thin air.
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("first healthy tick should not recover anything")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "public fail emits immediately")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery follows real warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "second cycle warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "second cycle recovery")
}
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
// has ever been Healthy, subsequent failures skip the grace window even
// if classification says "routed + not connected". The system has
// proved it can work, so any new failure is real.
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
fx := newProjTestFixture(t)
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
fx.server.warningDelayBase = time.Hour
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
// Establish "ever healthy".
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectNoEvent("first healthy tick")
// Peer drops. Query fails. Routed + not connected → normally grace,
// but everHealthy flag bypasses it.
fx.active = nil
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
}
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
// from the design discussion: once a group has been healthy, a brief
// reconnect that produces a failing tick will fire warning + recovery.
// This is by design: user-visible blips are accurate signal, not noise.
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "blip warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "blip recovery")
}
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
// rule: a group with at least one public upstream is in the "immediate"
// category regardless of the other upstreams' routing, because the
// public one has no peer-startup excuse. Prevents public-DNS failures
// from being hidden behind a routed sibling.
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
events := sub.Events()
public := netip.MustParseAddrPort("8.8.8.8:53")
overlay := netip.MustParseAddrPort("100.64.0.1:53")
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
},
}
stub := &healthStubHandler{
health: map[netip.AddrPort]UpstreamHealth{
public: {LastFail: time.Now(), LastErr: "servfail"},
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-events:
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected immediate warning because group contains a public upstream")
}
}
func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
flat := handler.flatUpstreams()
assert.Len(t, flat, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
if upstream.Addr() == expected {
found = true
break

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
"github.com/godbus/dbus/v5"
@@ -40,10 +41,17 @@ const (
)
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
ifaceName string
dbusLinkObject dbus.ObjectPath
ifaceName string
wgIndex int
origNameservers []netip.Addr
}
const (
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
)
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
return &systemdDbusConfigurator{
c := &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
wgIndex: iface.Index,
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
// every default-route link except our own WG link. Non-default-route links
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
// actually serve host queries.
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("list interfaces: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, iface := range ifaces {
if !s.isCandidateLink(iface) {
continue
}
linkPath, err := getSystemdLinkPath(iface.Index)
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
continue
}
for _, addr := range readSystemdLinkDNS(linkPath) {
addr = normalizeSystemdAddr(addr, iface.Name)
if !addr.IsValid() {
continue
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
if iface.Index == s.wgIndex {
return false
}
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
return false
}
return true
}
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
// Returns the zero Addr to signal "skip this entry".
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
return netip.Addr{}
}
if addr.IsLinkLocalUnicast() {
return addr.WithZone(ifaceName)
}
return addr
}
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return "", fmt.Errorf("dbus resolve1: %w", err)
}
defer closeConn()
var p string
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
return "", err
}
return dbus.ObjectPath(p), nil
}
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return false
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
if err != nil {
return false
}
b, ok := v.Value().(bool)
return ok && b
}
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([][]any)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
if len(entry) < 2 {
continue
}
raw, ok := entry[1].([]byte)
if !ok {
continue
}
addr, ok := netip.AddrFromSlice(raw)
if !ok {
continue
}
out = append(out, addr)
}
return out
}
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}
func (s *systemdDbusConfigurator) supportCustomPort() bool {

View File

@@ -1,3 +1,32 @@
// Package dns implements the client-side DNS stack: listener/service on the
// peer's tunnel address, handler chain that routes questions by domain and
// priority, and upstream resolvers that forward what remains to configured
// nameservers.
//
// # Upstream resolution and the race model
//
// When two or more nameserver groups target the same domain, DefaultServer
// merges them into one upstream handler whose state is:
//
// upstreamResolverBase
// └── upstreamServers []upstreamRace // one entry per source NS group
// └── []netip.AddrPort // primary, fallback, ...
//
// Each source nameserver group contributes one upstreamRace. Within a race
// upstreams are tried in order: the next is used only on failure (timeout,
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
// the walk. When more than one race exists, ServeDNS fans out one
// goroutine per race and returns the first valid answer, cancelling the
// rest. A handler with a single race skips the fan-out.
//
// # Health projection
//
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
// periodically merges these snapshots across handlers and projects them
// into peer.NSGroupState. There is no active probing: a group is marked
// unhealthy only when every seen upstream has a recent failure and none
// has a recent success. Healthy→unhealthy fires a single
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
package dns
import (
@@ -11,11 +40,8 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -25,11 +51,33 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var currentMTU uint16 = iface.DefaultMTU
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
// from one upstream means another upstream would return the same answer:
// DNSSEC validation outcomes and policy-based blocks. Transient errors
// (network, cached, not ready) are not included.
var nonRetryableEDECodes = map[uint16]struct{}{
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
dns.ExtendedErrorCodeDNSBogus: {},
dns.ExtendedErrorCodeSignatureExpired: {},
dns.ExtendedErrorCodeSignatureNotYetValid: {},
dns.ExtendedErrorCodeDNSKEYMissing: {},
dns.ExtendedErrorCodeRRSIGsMissing: {},
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
dns.ExtendedErrorCodeNSECMissing: {},
dns.ExtendedErrorCodeBlocked: {},
dns.ExtendedErrorCodeCensored: {},
dns.ExtendedErrorCodeFiltered: {},
dns.ExtendedErrorCodeProhibited: {},
}
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
type privateClientIface interface {
Name() string
@@ -46,15 +94,17 @@ const (
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
// within one race, so a slow primary can't eat the whole race budget.
raceMaxTotalTimeout = 5 * time.Second
// raceMinPerUpstreamTimeout is the floor applied when dividing
// raceMaxTotalTimeout across upstreams within a race.
raceMinPerUpstreamTimeout = 2 * time.Second
)
const (
protoUDP = "udp"
@@ -63,6 +113,69 @@ const (
type dnsProtocolKey struct{}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
// upstreamRace is an ordered list of upstreams derived from one configured
// nameserver group. Order matters: the first upstream is tried first, the
// second only on failure, and so on. Multiple upstreamRace values coexist
// inside one resolver when overlapping nameserver groups target the same
// domain; those races run in parallel and the first valid answer wins.
type upstreamRace []netip.AddrPort
// UpstreamHealth is the last query-path outcome for a single upstream,
// consumed by nameserver-group status projection.
type UpstreamHealth struct {
LastOk time.Time
LastFail time.Time
LastErr string
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []upstreamRace
domain domain.Domain
upstreamTimeout time.Duration
healthMu sync.RWMutex
health map[netip.AddrPort]*UpstreamHealth
statusRecorder *peer.Status
// selectedRoutes returns the current set of client routes the admin
// has enabled. Called lazily from the query hot path when an upstream
// might need a tunnel-bound client (iOS) and from health projection.
selectedRoutes func() route.HAMap
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
type raceResult struct {
msg *dns.Msg
upstream netip.AddrPort
protocol string
ede string
failures []upstreamFailure
}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
@@ -79,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -103,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []netip.AddrPort
domain string
disabled bool
successCount atomic.Int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{
ctx: ctx,
cancel: cancel,
domain: domain,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
statusRecorder: statusRecorder,
ctx: ctx,
cancel: cancel,
domain: d,
upstreamTimeout: UpstreamTimeout,
statusRecorder: statusRecorder,
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("Upstream %s", u.upstreamServers)
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}
// ID returns the unique handler ID
// ID returns the unique handler ID. Race groupings and within-race
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
// the same servers but with different semantics (serial fallback vs
// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers)
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
hash.Write([]byte(u.domain.PunycodeString() + ":"))
for _, race := range u.upstreamServers {
hash.Write([]byte("["))
for _, s := range race {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
hash.Write([]byte("]"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
@@ -173,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
}
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
u.cancel()
}
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
// flatUpstreams is for logging and ID hashing only, not for dispatch.
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
var out []netip.AddrPort
for _, g := range u.upstreamServers {
out = append(out, g...)
}
return out
}
// setSelectedRoutes swaps the accessor used to classify overlay-routed
// upstreams. Called when route sources are wired after the handler was
// built (permanent / iOS constructors).
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
u.selectedRoutes = selected
}
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
if len(servers) == 0 {
return
}
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
}
// ServeDNS handles a DNS request
@@ -221,59 +314,226 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
groups := u.upstreamServers
switch len(groups) {
case 0:
return false, nil
case 1:
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
default:
return u.raceAll(ctx, w, r, groups, logger)
}
}
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
res := u.tryRace(ctx, r, group)
if res.msg == nil {
return false, res.failures
}
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, res.failures
}
// raceAll runs one worker per group in parallel, taking the first valid
// answer and cancelling the rest.
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
raceCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Buffer sized to len(groups) so workers never block on send, even
// after the coordinator has returned.
results := make(chan raceResult, len(groups))
for _, g := range groups {
// tryRace clones the request per attempt, so workers never share
// a *dns.Msg and concurrent EDNS0 mutations can't race.
go func(g upstreamRace) {
results <- u.tryRace(raceCtx, r, g)
}(g)
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
for range groups {
select {
case res := <-results:
failures = append(failures, res.failures...)
if res.msg != nil {
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, failures
}
case <-ctx.Done():
return false, failures
}
}
return false, failures
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
timeout := u.upstreamTimeout
if len(group) > 1 {
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
// still honor raceMinPerUpstreamTimeout as a floor for correctness
// on slow links, but the outer context ensures the combined walk
// cannot exceed the cap regardless of group size.
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
}
var failures []upstreamFailure
for _, upstream := range group {
if ctx.Err() != nil {
return raceResult{failures: failures}
}
// Clone the request per attempt: the exchange path mutates EDNS0
// options in-place, so reusing the same *dns.Msg across sequential
// upstreams would carry those mutations (e.g. a reduced UDP size)
// into the next attempt.
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
if failure != nil {
failures = append(failures, *failure)
continue
}
res.failures = failures
return res
}
return raceResult{failures: failures}
}
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// The caller already passed a per-attempt copy, so we can mutate r
// directly; hadEdns reflects the original client request's state and
// controls whether we strip the OPT from the response.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(upstreamUDPSize(), false)
}
startTime := time.Now()
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
if err != nil {
return u.handleUpstreamError(err, upstream, startTime)
// A parent cancellation (e.g., another race won and the coordinator
// cancelled the losers) is not an upstream failure. Check both the
// error chain and the parent context: a transport may surface the
// cancellation as a read/deadline error rather than context.Canceled.
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
}
failure := u.handleUpstreamError(err, upstream, startTime)
u.markUpstreamFail(upstream, failure.reason)
return raceResult{}, failure
}
if rm == nil || !rm.Response {
return &upstreamFailure{upstream: upstream, reason: "no response"}
u.markUpstreamFail(upstream, "no response")
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}
// healthEntry returns the mutable health record for addr, lazily creating
// the map and the entry. Caller must hold u.healthMu.
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
if u.health == nil {
u.health = make(map[netip.AddrPort]*UpstreamHealth)
}
h := u.health[addr]
if h == nil {
h = &UpstreamHealth{}
u.health[addr] = h
}
return h
}
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastOk = time.Now()
h.LastFail = time.Time{}
h.LastErr = ""
}
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastFail = time.Now()
h.LastErr = reason
}
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
u.healthMu.RLock()
defer u.healthMu.RUnlock()
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
for k, v := range u.health {
out[k] = *v
}
return out
}
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
// derived from the tunnel MTU and bounded against underflow.
func upstreamUDPSize() uint16 {
if currentMTU > ipUDPHeaderSize {
return currentMTU - ipUDPHeaderSize
}
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
@@ -289,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
if proto != "" {
resutil.SetMeta(w, "upstream_protocol", proto)
}
// Clear Zero bit from external responses to prevent upstream servers from
@@ -303,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
totalUpstreams := len(u.flatUpstreams())
failedCount := len(failures)
failureSummary := formatFailures(failures)
@@ -337,117 +605,32 @@ func formatFailures(failures []upstreamFailure) string {
return strings.Join(parts, ", ")
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
u.mutex.Lock()
defer u.mutex.Unlock()
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
// nonRetryableEDE returns the first non-retryable EDE code carried in the
// response, if any.
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
opt := rm.IsEdns0()
if opt == nil {
return 0, false
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
var errs *multierror.Error
for _, upstream := range u.upstreamServers {
wg.Add(1)
go func(upstream netip.AddrPort) {
defer wg.Done()
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
if err != nil {
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
success = true
mu.Unlock()
}(upstream)
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errs.ErrorOrNil())
if u.statusRecorder == nil {
return
for _, o := range opt.Option {
ede, ok := o.(*dns.EDNS0_EDE)
if !ok {
continue
}
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
return ede.InfoCode, true
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
)
}
return 0, false
}
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolverBase) waitUntilResponse() {
exponentialBackOff := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.1,
MaxInterval: u.reactivatePeriod,
MaxElapsedTime: 0,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
// edeName returns a human-readable name for an EDE code, falling back to
// the numeric code when unknown.
func edeName(code uint16) string {
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
return name
}
operation := func() error {
select {
case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default:
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
return nil
}
}
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil {
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
return fmt.Sprintf("EDE %d", code)
}
// isTimeout returns true if the given error is a network timeout error.
@@ -461,45 +644,6 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable(err error) {
if u.disabled {
return
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
}
func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
@@ -511,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -537,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
rm, t, err := client.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
@@ -562,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -584,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
// so net.Dialer doesn't reject the TCP dial with "mismatched local
// address type".
func toTCPClient(c *dns.Client) *dns.Client {
tcp := *c
tcp.Net = protoTCP
if tcp.Dialer == nil {
return &tcp
}
d := *tcp.Dialer
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
}
tcp.Dialer = &d
return &tcp
}
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
@@ -725,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
// haMapRouteCount returns the total number of routes across all HA
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
// routes per key, so len(hm) is the number of HA groups, not routes.
func haMapRouteCount(hm route.HAMap) int {
total := 0
for _, routes := range hm {
total += len(routes)
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
return total
}
// haMapContains checks whether ip is covered by any concrete prefix in
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
// routes carry a placeholder Network that can't be prefix-checked, so we
// can't know at this point whether ip is reached through one. Callers
// decide how to interpret the unknown: health projection treats it as
// "possibly routed" to avoid emitting false-positive warnings during
// startup, while iOS dial selection requires a concrete match before
// binding to the tunnel.
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
for _, routes := range hm {
for _, r := range routes {
if r.IsDynamic() {
haveDynamic = true
continue
}
if r.Network.Contains(ip) {
return true, haveDynamic
}
}
}
return false, haveDynamic
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -26,9 +27,9 @@ func newUpstreamResolver(
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,

View File

@@ -12,6 +12,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -24,9 +25,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolverIOS struct {
@@ -27,9 +28,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP = upstreamIP.Unmap()
}
addr := u.wgIface.Address()
var routed bool
if u.selectedRoutes != nil {
// Only a concrete prefix match binds to the tunnel: dialing
// through a private client for an upstream we can't prove is
// routed would break public resolvers.
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
}
needsPrivate := addr.Network.Contains(upstreamIP) ||
addr.IPv6Net.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
routed
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return ExchangeWithFallback(nil, client, r, upstream)
return ExchangeWithFallback(ctx, client, r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.addRace(servers)
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// exchange mock implementation of exchange from upstreamResolver
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
msg *dns.Msg
err error
delay time.Duration
}
type mockUpstreamResolverPerServer struct {
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
r, ok := c.responses[upstream]
if !ok {
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
r: new(dns.Msg),
rtt: time.Millisecond,
}
resolver := &upstreamResolverBase{
ctx: context.TODO(),
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: time.Microsecond * 100,
}
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
failed := false
resolver.deactivate = func(error) {
failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
}
reactivated := false
resolver.reactivate = func() {
reactivated = true
}
resolver.ProbeAvailability(context.TODO())
if !failed {
t.Errorf("expected that resolving was deactivated")
return
}
if !resolver.disabled {
t.Errorf("resolver should be Disabled")
return
}
time.Sleep(time.Millisecond * 200)
if !reactivated {
t.Errorf("expected that resolving was reactivated")
return
}
if resolver.disabled {
t.Errorf("should be enabled")
if r.delay > 0 {
select {
case <-time.After(r.delay):
case <-ctx.Done():
return nil, c.rtt, ctx.Err()
}
}
return r.msg, c.rtt, r.err
}
func TestUpstreamResolver_Failover(t *testing.T) {
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: trackingClient,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamServers: []netip.AddrPort{upstream},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
// configured for the same domain, with one broken group. The merge+race
// path should answer as fast as the working group and not pay the timeout
// of the broken one on every query.
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
broken := netip.MustParseAddrPort("192.0.2.1:53")
working := netip.MustParseAddrPort("192.0.2.2:53")
successAnswer := "192.0.2.100"
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
// Force the broken upstream to only unblock via timeout /
// cancellation so the assertion below can't pass if races
// were run serially.
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: 250 * time.Millisecond,
}
resolver.addRace([]netip.AddrPort{broken})
resolver.addRace([]netip.AddrPort{working})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
start := time.Now()
resolver.ServeDNS(responseWriter, inputMSG)
elapsed := time.Since(start)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
require.NotEmpty(t, responseMSG.Answer)
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
// Working group answers in a single RTT; the broken group's
// timeout (100ms) must not block the response.
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
}
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
// resolver returns SERVFAIL rather than leaking a partial response.
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a})
resolver.addRace([]netip.AddrPort{b})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
require.NotNil(t, responseMSG)
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
}
// TestUpstreamResolver_HealthTracking verifies that query-path results are
// recorded into per-upstream health, which is what projects back to
// NSGroupState for status reporting.
func TestUpstreamResolver_HealthTracking(t *testing.T) {
ok := netip.MustParseAddrPort("192.0.2.10:53")
bad := netip.MustParseAddrPort("192.0.2.11:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{ok, bad})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, ok)
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
assert.Empty(t, health[ok].LastErr)
// bad upstream was never tried because ok answered first; its health
// should remain unset.
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
var receivedUDPSize atomic.Uint32
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
receivedUDPSize.Store(uint32(opt.UDPSize()))
}
m := new(dns.Msg)
m.SetReply(r)
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
"upstream should see capped EDNS0, not the client's 4096")
}
@@ -770,3 +847,132 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if len(codes) == 0 {
return m
}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.SetUDPSize(dns.MinMsgSize)
for _, c := range codes {
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
}
m.Extra = append(m.Extra, opt)
return m
}
func TestNonRetryableEDE(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantOK bool
wantCode uint16
}{
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
{
name: "opt without ede",
msg: func() *dns.Msg {
m := msgWithEDE(dns.RcodeServerFailure)
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
m.Extra = []dns.RR{opt}
return m
}(),
},
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
{
name: "first non-retryable wins",
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
wantOK: true,
wantCode: dns.ExtendedErrorCodeDNSBogus,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, ok := nonRetryableEDE(tc.msg)
assert.Equal(t, tc.wantOK, ok, "ok should match")
if tc.wantOK {
assert.Equal(t, tc.wantCode, code, "code should match")
}
})
}
}
func TestEDEName(t *testing.T) {
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
var queried []string
tracking := &trackingMockClient{
inner: &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): {msg: servfailWithEDE},
upstream2.String(): {msg: successResp},
},
rtt: time.Millisecond,
},
queriedUpstreams: &queried,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []upstreamRace{{upstream1, upstream2}},
upstreamTimeout: UpstreamTimeout,
}
var written *dns.Msg
w := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
written = m
return nil
},
}
// Client query without EDNS0 must not see an OPT in the response.
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(w, q)
require.NotNil(t, written, "response must be written")
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
assert.Len(t, queried, 1, "only first upstream should be queried")
assert.Equal(t, upstream1.String(), queried[0])
for _, rr := range written.Extra {
_, isOPT := rr.(*dns.OPT)
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
}
}

View File

@@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
@@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
go e.dnsServer.ProbeAvailability()
return nil
}
@@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:

View File

@@ -53,6 +53,7 @@ type Manager interface {
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetActiveClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetActiveClientRoutes returns the subset of selected client routes
// that are currently reachable: the route's peer is Connected and is
// the one actively carrying the route (not just an HA sibling).
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
m.mux.Lock()
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
recorder := m.statusRecorder
m.mux.Unlock()
if recorder == nil {
return selected
}
out := make(route.HAMap, len(selected))
for id, routes := range selected {
for _, r := range routes {
st, err := recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if st.ConnStatus != peer.StatusConnected {
continue
}
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
continue
}
out[id] = routes
break
}
}
return out
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
}
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
if len(routes) == 0 {
return false
}
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
}
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {

View File

@@ -19,6 +19,7 @@ type MockManager struct {
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetActiveClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
return nil
}
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
if m.GetActiveClientRoutesFunc != nil {
return m.GetActiveClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,10 +13,6 @@ import (
"github.com/netbirdio/netbird/route"
)
const (
exitNodeCIDR = "0.0.0.0/0"
)
type RouteSelector struct {
mu sync.RWMutex
deselectedRoutes map[route.NetID]struct{}
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected
return isSelected
return rs.isSelectedLocked(routeID)
}
// FilterSelected removes unselected routes from the provided map.
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
filtered := route.HAMap{}
for id, rt := range routes {
netID := id.NetID()
_, deselected := rs.deselectedRoutes[netID]
if !deselected {
if !rs.isDeselectedLocked(id.NetID()) {
filtered[id] = rt
}
}
return filtered
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
filtered := make(route.HAMap, len(routes))
for id, rt := range routes {
netID := id.NetID()
if rs.isDeselected(netID) {
if rs.isDeselectedLocked(netID) {
continue
}
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected || rs.deselectAll
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelections() {
// user made explicit selects/deselects
if rs.IsSelected(netID) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func (rs *RouteSelector) hasUserSelections() bool {
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {

View File

@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-default-route entry named "corp-v6" is not an exit node and
// must not be skipped because its v4 base "corp" is deselected.
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
routes := route.HAMap{
"corp-v6|10.0.0.0/8": {corpV6},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
})
}
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
// SkipAutoApply flag governs each untouched route independently, even when
// the user has explicit selections on other routes.
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
autoApplied := &route.Route{
NetID: "Auto",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: false,
}
skipApply := &route.Route{
NetID: "Skip",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"Auto|0.0.0.0/0": {autoApplied},
"Skip|0.0.0.0/0": {skipApply},
}
rs := routeselector.NewRouteSelector()
// User makes an unrelated explicit selection elsewhere.
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
}
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
// as exit nodes by the selector's filter path.
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
v6Exit := &route.Route{
NetID: "V6Only",
Network: netip.MustParsePrefix("::/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"V6Only|::/0": {v6Exit},
}
rs := routeselector.NewRouteSelector()
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
}
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}

View File

@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -64,6 +64,13 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -76,10 +83,28 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

View File

@@ -0,0 +1,93 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
func TestPersistLoginOverrides(t *testing.T) {
strPtr := func(s string) *string { return &s }
tests := []struct {
name string
initialMgmtURL string
initialPSK string
newMgmtURL string
newPSK *string
wantMgmtURL string
wantPSK string
}{
{
name: "persist new management URL",
initialMgmtURL: "https://old.example.com:33073",
newMgmtURL: "https://new.example.com:33073",
wantMgmtURL: "https://new.example.com:33073",
},
{
name: "persist new pre-shared key",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "old-key",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "new-key",
},
{
name: "persist both",
initialMgmtURL: "https://old.example.com:33073",
initialPSK: "old-key",
newMgmtURL: "https://new.example.com:33073",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://new.example.com:33073",
wantPSK: "new-key",
},
{
name: "no inputs preserves existing",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
{
name: "empty PSK pointer is ignored",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
newPSK: strPtr(""),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
origDefault := profilemanager.DefaultConfigPath
t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault })
dir := t.TempDir()
profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json")
seed := profilemanager.ConfigInput{
ConfigPath: profilemanager.DefaultConfigPath,
ManagementURL: tt.initialMgmtURL,
}
if tt.initialPSK != "" {
seed.PreSharedKey = strPtr(tt.initialPSK)
}
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")
cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath)
require.NoError(t, err, "read back config")
require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL")
require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key")
})
}
}

View File

@@ -490,6 +490,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.mutex.Unlock()
if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil {
log.Errorf("failed to persist login overrides: %v", err)
return nil, fmt.Errorf("persist login overrides: %w", err)
}
config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
@@ -964,7 +969,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return &proto.LogoutResponse{}, nil
}
// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
cfgPath, err := activeProf.FilePath()
if err != nil {
@@ -1766,3 +1771,29 @@ func sendTerminalNotification() error {
return wallCmd.Wait()
}
// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the
// active profile config so that subsequent reads pick them up. Empty/nil values are ignored.
func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error {
if preSharedKey != nil && *preSharedKey == "" {
preSharedKey = nil
}
if managementURL == "" && preSharedKey == nil {
return nil
}
cfgPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("active profile file path: %w", err)
}
input := profilemanager.ConfigInput{
ConfigPath: cfgPath,
ManagementURL: managementURL,
PreSharedKey: preSharedKey,
}
if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil {
return fmt.Errorf("update config: %w", err)
}
return nil
}

View File

@@ -25,6 +25,7 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/netrelay"
)
const (
@@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
continue
}
go c.handleLocalForward(localConn, remoteAddr)
go c.handleLocalForward(ctx, localConn, remoteAddr)
}
}()
@@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
}
// handleLocalForward handles a single local port forwarding connection
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local port forwarding: close local connection: %v", err)
@@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
}
}()
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
select {
case <-ctx.Done():
return
case newChan := <-channelRequests:
case newChan, ok := <-channelRequests:
if !ok {
return
}
if newChan != nil {
go c.handleRemoteForwardChannel(newChan, localAddr)
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
}
}
}
}
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
channel, reqs, err := newChan.Accept()
if err != nil {
return
@@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
go ssh.DiscardRequests(reqs)
localConn, err := net.Dial("tcp", localAddr)
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
// channel open indefinitely; the relay itself runs under the outer ctx.
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
var dialer net.Dialer
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
cancelDial()
if err != nil {
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
return
}
defer func() {
@@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
}()
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
// tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string {
return addresses
}
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
<-done
<-done
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}
_ = conn1.Close()
_ = conn2.Close()
}

View File

@@ -229,18 +229,35 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
sshConfigPathTmp := sshConfigPath + ".tmp"
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
}
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
tmp, err := os.CreateTemp(m.sshConfigDir, m.sshConfigFile+".*.tmp")
if err != nil {
return fmt.Errorf("create temp SSH config: %w", err)
}
tmpPath := tmp.Name()
defer func() {
if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) {
log.Debugf("remove temp SSH config %s: %v", tmpPath, err)
}
}()
if err := tmp.Close(); err != nil {
return fmt.Errorf("close temp SSH config %s: %w", tmpPath, err)
}
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
if err := writeFileWithTimeout(tmpPath, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", tmpPath, err)
}
if err := os.Chmod(tmpPath, 0644); err != nil {
return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, sshConfigPath); err != nil {
return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {

View File

@@ -17,7 +17,7 @@ import (
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util/netrelay"
)
const privilegedPortThreshold = 1024
@@ -357,7 +357,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
return
}
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
}
// openForwardChannel creates an SSH forwarded-tcpip channel

View File

@@ -8,9 +8,9 @@ import (
"fmt"
"io"
"net"
"strconv"
"net/netip"
"slices"
"strconv"
"strings"
"sync"
"time"
@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -53,6 +54,10 @@ const (
DefaultJWTMaxTokenAge = 10 * 60
)
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
// the forwarded destination before rejecting the SSH channel.
const directTCPIPDialTimeout = 30 * time.Second
var (
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
ErrUserNotFound = errors.New("user not found")
@@ -933,5 +938,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s", hostPort)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
}
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
// DirectTCPIPHandler. The upstream handler closes both sides on the first
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
// own terms.
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
dest := net.JoinHostPort(host, strconv.Itoa(port))
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
dconn, err := dialer.DialContext(ctx, "tcp", dest)
if err != nil {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go cryptossh.DiscardRequests(reqs)
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
}

View File

@@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
}
func isDefaultRoute(routeRange string) bool {
return routeRange == "0.0.0.0/0" || routeRange == "::/0"
// routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0",
// "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry.
for _, part := range strings.Split(routeRange, ",") {
switch strings.TrimSpace(part) {
case "0.0.0.0/0", "::/0":
return true
}
}
return false
}
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {

View File

@@ -133,13 +133,18 @@ type ManagementConfig struct {
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
MfaSessionMaxLifetime string `yaml:"mfaSessionMaxLifetime"`
MfaSessionIdleTimeout string `yaml:"mfaSessionIdleTimeout"`
MfaSessionRememberMe bool `yaml:"mfaSessionRememberMe"`
SessionCookieEncryptionKey string `yaml:"sessionCookieEncryptionKey"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
DashboardPostLogoutRedirectURIs []string `yaml:"dashboardPostLogoutRedirectURIs"`
}
// AuthStorageConfig contains auth storage settings
@@ -581,10 +586,14 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
}
cfg := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
MfaSessionMaxLifetime: mgmt.Auth.MfaSessionMaxLifetime,
MfaSessionIdleTimeout: mgmt.Auth.MfaSessionIdleTimeout,
MfaSessionRememberMe: mgmt.Auth.MfaSessionRememberMe,
SessionCookieEncryptionKey: mgmt.Auth.SessionCookieEncryptionKey,
Storage: idp.EmbeddedStorageConfig{
Type: authStorageType,
Config: idp.EmbeddedStorageTypeConfig{
@@ -592,8 +601,9 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
DSN: authStorageDSN,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
DashboardPostLogoutRedirectURIs: mgmt.Auth.DashboardPostLogoutRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {

View File

@@ -86,6 +86,13 @@ server:
issuer: "https://example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# MFA session settings (applies when TOTP is enabled for an account)
# mfaSessionMaxLifetime: "24h" # Max duration for an MFA session from creation
# mfaSessionIdleTimeout: "1h" # MFA session expires after this idle period
# mfaSessionRememberMe: false # Pre-check "remember me" on login so the MFA session persists across tabs/restarts
# Optional AES key for encrypting embedded IdP session cookies. Can also be set via NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY.
# Must be 16/24/32 raw bytes or base64-encoded to one of those lengths (for example: openssl rand -hex 16).
# sessionCookieEncryptionKey: ""
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
@@ -93,6 +100,9 @@ server:
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# OAuth2 post-logout redirect URIs for dashboard (RP-initiated logout)
# dashboardPostLogoutRedirectURIs:
# - "https://app.example.com/"
# Optional initial admin user
# owner:
# email: "admin@example.com"

46
go.mod
View File

@@ -14,7 +14,7 @@ require (
github.com/onsi/gomega v1.27.6
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.4
github.com/spf13/cobra v1.10.1
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
@@ -41,11 +41,11 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.14.1
github.com/coreos/go-oidc/v3 v3.18.0
github.com/creack/pty v1.1.24
github.com/crowdsecurity/crowdsec v1.7.7
github.com/crowdsecurity/go-cs-bouncer v0.0.21
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex v2.13.0+incompatible
github.com/dexidp/dex/api/v2 v2.4.0
github.com/ebitengine/purego v0.8.4
github.com/eko/gocache/lib/v4 v4.2.0
@@ -53,9 +53,9 @@ require (
github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.3
github.com/go-jose/go-jose/v4 v4.1.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
@@ -72,7 +72,7 @@ require (
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.59
github.com/miekg/dns v1.1.72
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
@@ -113,7 +113,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.64.0
go.opentelemetry.io/otel/metric v1.43.0
go.opentelemetry.io/otel/sdk/metric v1.43.0
go.uber.org/mock v0.5.2
go.uber.org/mock v0.6.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
@@ -141,7 +141,7 @@ require (
filippo.io/edwards25519 v1.1.1 // indirect
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
github.com/Azure/go-ntlmssp v0.1.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
@@ -168,6 +168,7 @@ require (
github.com/aws/smithy-go v1.23.0 // indirect
github.com/beevik/etree v1.6.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -183,6 +184,7 @@ require (
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.1 // indirect
github.com/fxamacker/cbor/v2 v2.9.1 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect
github.com/fyne-io/glfw-js v0.3.0 // indirect
github.com/fyne-io/image v0.1.1 // indirect
@@ -190,7 +192,7 @@ require (
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
github.com/go-ldap/ldap/v3 v3.4.13 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
@@ -206,11 +208,15 @@ require (
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
@@ -218,7 +224,13 @@ require (
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect
github.com/hashicorp/go-sockaddr v1.0.7 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/hashicorp/hcl v1.0.1-vault-7 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -238,13 +250,13 @@ require (
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/koron/go-ssdp v0.0.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/lib/pq v1.12.3 // indirect
github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/mattn/go-sqlite3 v1.14.42 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
@@ -265,8 +277,10 @@ require (
github.com/nxadm/tail v1.4.11 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/openbao/openbao/api/v2 v2.5.1 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
@@ -275,11 +289,13 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/russellhaering/goxmldsig v1.6.0 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shirou/gopsutil/v4 v4.25.8 // indirect
github.com/shoenig/go-m1cpu v0.2.1 // indirect
@@ -288,11 +304,13 @@ require (
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tinylib/msgp v1.6.3 // indirect
github.com/tklauser/go-sysconf v0.3.15 // indirect
github.com/tklauser/numcpus v0.10.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.mongodb.org/mongo-driver v1.17.9 // indirect
@@ -319,10 +337,12 @@ replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-202
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1
replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

101
go.sum
View File

@@ -5,6 +5,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3R
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
@@ -23,8 +25,8 @@ github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
@@ -91,6 +93,8 @@ github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sL
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -117,8 +121,8 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
@@ -130,14 +134,10 @@ github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy1
github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA=
github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4=
github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -156,6 +156,8 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM
github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA=
github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0=
github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
@@ -171,6 +173,8 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs=
github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI=
github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk=
@@ -189,10 +193,10 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ=
github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -229,12 +233,20 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc=
github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU=
github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8=
github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4CyGN+1Q=
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
@@ -243,8 +255,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -276,6 +288,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo=
github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc=
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -308,15 +324,29 @@ github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xN
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48=
github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4=
github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw=
github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I=
github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
@@ -387,8 +417,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ=
github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
@@ -406,9 +436,13 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo=
github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
@@ -421,8 +455,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
@@ -451,8 +485,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 h1:4TaYr9O4xX0D2kszeOLclTiCbA3eHq3xWV+9ILJbIYs=
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1/go.mod h1:IHH+H8vK2GfqtIt5u/5OdPh18yk0oDHuj2vz5+Goetg=
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 h1:neE7z+FPUkldl3faK/Jt+hJK2L+1XfQ1W33TQhU9m88=
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
@@ -489,6 +525,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o=
github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -501,6 +539,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs=
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
@@ -542,6 +582,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
@@ -565,6 +607,8 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks=
github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
@@ -587,8 +631,8 @@ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
@@ -628,6 +672,8 @@ github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0
github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o=
github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40=
github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo=
github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s=
github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4=
github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4=
@@ -646,6 +692,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -690,14 +738,15 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@@ -51,6 +51,70 @@ type YAMLConfig struct {
// StaticPasswords cause the server use this list of passwords rather than
// querying the storage.
StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"`
// Sessions holds authentication session configuration.
// Requires DEX_SESSIONS_ENABLED=true feature flag.
Sessions *Sessions `yaml:"sessions" json:"sessions"`
// MFA holds multi-factor authentication configuration.
MFA MFAConfig `yaml:"mfa" json:"mfa"`
}
type Sessions struct {
// CookieName is the name of the session cookie. Defaults to "dex_session".
CookieName string `yaml:"cookieName" json:"cookieName"`
// AbsoluteLifetime is the maximum session lifetime from creation. Defaults to "24h".
AbsoluteLifetime string `yaml:"absoluteLifetime" json:"absoluteLifetime"`
// ValidIfNotUsedFor is the idle timeout. Defaults to "1h".
ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"`
// RememberMeCheckedByDefault controls the default state of the "remember me" checkbox.
RememberMeCheckedByDefault *bool `yaml:"rememberMeCheckedByDefault" json:"rememberMeCheckedByDefault"`
// CookieEncryptionKey is the AES key for encrypting session cookies.
// Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256.
// If empty, cookies are not encrypted.
CookieEncryptionKey string `yaml:"cookieEncryptionKey" json:"cookieEncryptionKey"`
// SSOSharedWithDefault is the default SSO sharing policy for clients without explicit ssoSharedWith.
// "all" = share with all clients, "none" = share with no one (default: "none").
SSOSharedWithDefault string `yaml:"ssoSharedWithDefault" json:"ssoSharedWithDefault"`
}
type MFAConfig struct {
Authenticators []MFAAuthenticator `yaml:"authenticators" json:"authenticators"`
}
type MFAAuthenticator struct {
ID string `yaml:"id" json:"id"`
Type string `yaml:"type" json:"type"`
Config map[string]interface{} `yaml:"config" json:"config"`
ConnectorTypes []string `yaml:"connectorTypes" json:"connectorTypes"`
}
type TOTPConfig struct {
Issuer string `yaml:"issuer" json:"issuer"`
}
// WebAuthnConfig holds configuration for a WebAuthn authenticator.
type WebAuthnConfig struct {
// RPDisplayName is the human-readable relying party name shown in the browser
// dialog during key registration and authentication (e.g., "My Company SSO").
RPDisplayName string `yaml:"rpDisplayName" json:"rpDisplayName"`
// RPID is the relying party identifier — must match the domain in the browser
// address bar. If empty, derived from the issuer URL hostname.
// Example: "auth.example.com"
RPID string `yaml:"rpID" json:"rpID"`
// RPOrigins is the list of allowed origins for WebAuthn ceremonies.
// If empty, derived from the issuer URL (scheme + host).
// Example: ["https://auth.example.com"]
RPOrigins []string `yaml:"rpOrigins" json:"rpOrigins"`
// AttestationPreference controls what attestation data the authenticator should provide:
// "none" — don't request attestation (simpler, more private)
// "indirect" — authenticator may anonymize attestation (default)
// "direct" — request full attestation (for enterprise key model verification)
AttestationPreference string `yaml:"attestationPreference" json:"attestationPreference"`
// Timeout is the duration allowed for the browser WebAuthn ceremony
// (registration or login). Defaults to "60s".
Timeout string `yaml:"timeout" json:"timeout"`
}
// Web is the config format for the HTTP server.
@@ -116,7 +180,6 @@ type Storage struct {
Config map[string]interface{} `yaml:"config" json:"config"`
}
// Password represents a static user configuration
type Password storage.Password
func (p *Password) UnmarshalYAML(node *yaml.Node) error {
@@ -429,9 +492,98 @@ func (c *YAMLConfig) Validate() error {
if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 {
return fmt.Errorf("cannot specify static passwords without enabling password db")
}
return nil
}
func buildTotpConfig(auth MFAAuthenticator) (*server.TOTPProvider, error) {
data, err := json.Marshal(auth.Config)
if err != nil {
return nil, fmt.Errorf("failed to marshal TOTP config id: %s - %w", auth.ID, err)
}
var cfg TOTPConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse TOTP config id: %s - %w", auth.ID, err)
}
return server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes), nil
}
func buildWebAuthnConfig(auth MFAAuthenticator, issuerURL string) (*server.WebAuthnProvider, error) {
data, err := json.Marshal(auth.Config)
if err != nil {
return nil, fmt.Errorf("failed to marshal WebAuthn config id: %s - %w", auth.ID, err)
}
var cfg WebAuthnConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse WebAuthn config id: %s - %w", auth.ID, err)
}
provider, err := server.NewWebAuthnProvider(cfg.RPDisplayName, cfg.RPID, cfg.RPOrigins,
cfg.AttestationPreference, cfg.Timeout, issuerURL, auth.ConnectorTypes)
if err != nil {
return nil, fmt.Errorf("failed to create WebAuthn provider id: %s - err: %w", auth.ID, err)
}
return provider, nil
}
func buildMFAProviders(authenticators []MFAAuthenticator, issuerURL string, logger *slog.Logger) map[string]server.MFAProvider {
if len(authenticators) == 0 {
return nil
}
providers := make(map[string]server.MFAProvider, len(authenticators))
for _, auth := range authenticators {
switch auth.Type {
case "TOTP":
provider, err := buildTotpConfig(auth)
if err != nil {
logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err)
continue
}
providers[auth.ID] = provider
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
case "WebAuthn":
provider, err := buildWebAuthnConfig(auth, issuerURL)
if err != nil {
logger.Error("failed to parse WebAuthn config", "id", auth.ID, "err", err)
continue
}
providers[auth.ID] = provider
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
default:
logger.Error("unknown MFA authenticator type, skipping", "id", auth.ID, "type", auth.Type)
}
}
return providers
}
func buildSessionsConfig(sessions *Sessions) *server.SessionConfig {
if sessions == nil {
return nil
}
if sessions.RememberMeCheckedByDefault == nil {
defaultRememberMeCheckedByDefault := false
sessions.RememberMeCheckedByDefault = &defaultRememberMeCheckedByDefault
}
absoluteLifetime, _ := parseDuration(sessions.AbsoluteLifetime)
validIfNotUsedFor, _ := parseDuration(sessions.ValidIfNotUsedFor)
return &server.SessionConfig{
CookieEncryptionKey: []byte(sessions.CookieEncryptionKey),
CookieName: sessions.CookieName,
AbsoluteLifetime: absoluteLifetime,
ValidIfNotUsedFor: validIfNotUsedFor,
RememberMeCheckedByDefault: *sessions.RememberMeCheckedByDefault,
SSOSharedWithDefault: sessions.SSOSharedWithDefault,
}
}
// ToServerConfig converts YAMLConfig to dex server.Config
func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config {
cfg := server.Config{
@@ -448,6 +600,8 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
Dir: c.Frontend.Dir,
Extra: c.Frontend.Extra,
},
SessionConfig: buildSessionsConfig(c.Sessions),
MFAProviders: buildMFAProviders(c.MFA.Authenticators, c.Issuer, logger),
}
// Use embedded NetBird-styled templates if no custom dir specified
@@ -460,11 +614,6 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
}
// Apply expiry settings
if c.Expiry.SigningKeys != "" {
if d, err := parseDuration(c.Expiry.SigningKeys); err == nil {
cfg.RotateKeysAfter = d
}
}
if c.Expiry.IDTokens != "" {
if d, err := parseDuration(c.Expiry.IDTokens); err == nil {
cfg.IDTokensValidFor = d

View File

@@ -18,6 +18,7 @@ import (
dexapi "github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
@@ -70,7 +71,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Ensure data directory exists
if err := os.MkdirAll(config.DataDir, 0700); err != nil {
if err := os.MkdirAll(config.DataDir, 0o700); err != nil {
return nil, fmt.Errorf("failed to create data directory: %w", err)
}
@@ -101,6 +102,15 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
}
localSignerConfig := signer.LocalConfig{
KeysRotationPeriod: "6h",
}
localSigner, err := localSignerConfig.Open(ctx, stor, 24*time.Hour, time.Now, logger)
if err != nil {
return nil, fmt.Errorf("failed to create local signer: %w", err)
}
// Build Dex server config - use Dex's types directly
dexConfig := server.Config{
Issuer: issuer,
@@ -110,12 +120,12 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
ContinueOnConnectorFailure: true,
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
RotateKeysAfter: 6 * time.Hour,
IDTokensValidFor: 24 * time.Hour,
RefreshTokenPolicy: refreshPolicy,
Web: server.WebConfig{
Issuer: "NetBird",
},
Signer: localSigner,
}
dexSrv, err := server.NewServer(ctx, dexConfig)
@@ -167,6 +177,14 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
}
localSigner, err := getSigner(ctx, stor, yamlConfig, logger)
if err != nil {
stor.Close()
return nil, fmt.Errorf("failed to create local signer: %w", err)
}
dexConfig.Signer = localSigner
dexSrv, err := server.NewServer(ctx, dexConfig)
if err != nil {
stor.Close()
@@ -182,6 +200,32 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
}, nil
}
func getSigner(ctx context.Context, stor storage.Storage, yamlConfig *YAMLConfig, logger *slog.Logger) (signer.Signer, error) {
// Parse expiry durations
idTokensValidFor := 24 * time.Hour // default
if yamlConfig.Expiry.IDTokens != "" {
var err error
idTokensValidFor, err = parseDuration(yamlConfig.Expiry.IDTokens)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for id token expiry: %v", yamlConfig.Expiry.IDTokens, err)
}
}
localSignerConfig := &signer.LocalConfig{
KeysRotationPeriod: "720h", // 30 Days
}
if yamlConfig.Expiry.SigningKeys != "" {
if _, err := parseDuration(yamlConfig.Expiry.SigningKeys); err != nil {
return nil, fmt.Errorf("invalid config value %q for signing key expiry: %v", yamlConfig.Expiry.SigningKeys, err)
}
localSignerConfig.KeysRotationPeriod = yamlConfig.Expiry.SigningKeys
}
return localSignerConfig.Open(ctx, stor, idTokensValidFor, time.Now, logger)
}
// initializeStorage sets up connectors, passwords, and clients in storage
func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error {
if cfg.EnablePasswordDB {
@@ -241,6 +285,8 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
old.RedirectURIs = client.RedirectURIs
old.Name = client.Name
old.Public = client.Public
old.PostLogoutRedirectURIs = client.PostLogoutRedirectURIs
old.MFAChain = client.MFAChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update client %s: %w", client.ID, err)
@@ -253,9 +299,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
cfg := yamlConfig.ToServerConfig(stor, logger)
cfg.PrometheusRegistry = prometheus.NewRegistry()
if cfg.RotateKeysAfter == 0 {
cfg.RotateKeysAfter = 24 * 30 * time.Hour
}
if cfg.IDTokensValidFor == 0 {
cfg.IDTokensValidFor = 24 * time.Hour
}
@@ -450,10 +493,34 @@ func (p *Provider) Storage() storage.Storage {
return p.storage
}
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
for _, clientID := range clientIDs {
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
}
return nil
}
// Handler returns the Dex server as an http.Handler for embedding in another server.
// The handler expects requests with path prefix "/oauth2/".
func (p *Provider) Handler() http.Handler {
return p.dexServer
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Dex's /logout endpoint requires id_token_hint for RP-initiated logout with
// post_logout_redirect_uri. If the dashboard calls logout without one, avoid
// rendering Dex's non-actionable Bad Request page and send the user home.
if strings.HasSuffix(r.URL.Path, "/logout") && r.FormValue("id_token_hint") == "" {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
p.dexServer.ServeHTTP(w, r)
})
}
// CreateUser creates a new user with the given email, username, and password.

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@@ -144,6 +146,30 @@ func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) {
assert.Equal(t, knownEncodedID, reEncoded)
}
func TestHandlerRedirectsLogoutWithoutIDTokenHint(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-logout-handler-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
provider, err := NewProvider(ctx, &Config{
Issuer: "http://localhost:5556/oauth2",
Port: 5556,
DataDir: tmpDir,
})
require.NoError(t, err)
defer func() { _ = provider.Stop(ctx) }()
req := httptest.NewRequest(http.MethodGet, "/oauth2/logout?post_logout_redirect_uri=https://example.com", nil)
rec := httptest.NewRecorder()
provider.Handler().ServeHTTP(rec, req)
require.Equal(t, http.StatusSeeOther, rec.Code)
require.Equal(t, "/", rec.Header().Get("Location"))
}
func TestCreateUserInTempDB(t *testing.T) {
ctx := context.Background()

View File

@@ -0,0 +1,12 @@
{{ template "header.html" . }}
<script>globalThis.location.replace("/");</script>
<noscript>
<div class="nb-card">
<h1 class="nb-heading">Redirecting…</h1>
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
</div>
</noscript>
{{ template "footer.html" . }}

View File

@@ -0,0 +1,14 @@
{{ template "header.html" . }}
<div class="nb-card">
<h1 class="nb-heading">Logged Out</h1>
<p class="nb-subheading">You have been successfully logged out.</p>
{{ if .BackURL }}
<div class="nb-back-link">
<a href="{{ .BackURL }}" class="nb-link">&larr; Back to Application</a>
</div>
{{ end }}
</div>
{{ template "footer.html" . }}

View File

@@ -18,6 +18,7 @@
id="login"
name="login"
class="nb-input"
autocomplete="username"
placeholder="Enter your {{ .UsernamePrompt | lower }}"
{{ if .Username }}value="{{ .Username }}"{{ else }}autofocus{{ end }}
required
@@ -31,6 +32,7 @@
id="password"
name="password"
class="nb-input"
autocomplete="current-password"
placeholder="Enter your password"
{{ if .Invalid }}autofocus{{ end }}
required

View File

@@ -0,0 +1,44 @@
{{ template "header.html" . }}
<div class="nb-card">
<h1 class="nb-heading">Two-factor authentication</h1>
{{ if not (eq .QRCode "") }}
<p class="nb-subheading">Scan the QR code below using your authenticator app, then enter the code.</p>
<div style="text-align: center; margin: 1em 0;">
<img src="data:image/png;base64,{{ .QRCode }}" alt="QR code" width="200" height="200"/>
</div>
{{ else }}
<p class="nb-subheading">Enter the code from your authenticator app.</p>
{{ end }}
<form method="post" action="{{ .PostURL }}">
{{ if .Invalid }}
<div class="nb-error">
Invalid code. Please try again.
</div>
{{ end }}
<div class="nb-form-group">
<label class="nb-label" for="totp">One-time code</label>
<input
type="text"
id="totp"
name="totp"
class="nb-input"
inputmode="numeric"
pattern="[0-9]*"
maxlength="6"
autocomplete="one-time-code"
placeholder="000000"
autofocus
required
>
</div>
<button type="submit" id="submit-login" class="nb-btn">
Verify
</button>
</form>
</div>
{{ template "footer.html" . }}

View File

@@ -0,0 +1,12 @@
{{ template "header.html" . }}
<script>globalThis.location.replace("/");</script>
<noscript>
<div class="nb-card">
<h1 class="nb-heading">Redirecting…</h1>
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
</div>
</noscript>
{{ template "footer.html" . }}

View File

@@ -221,9 +221,13 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
@@ -570,7 +574,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
@@ -580,7 +584,7 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
return c.bufferSendUpdateAccountPeers(ctx, accountID)
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
@@ -616,7 +620,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -2,6 +2,7 @@ package manager
import (
"context"
"math/rand"
"sync"
"time"
@@ -11,240 +12,344 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
// cleanupWindow is the small grace period added on top of the
// staleness horizon before a sweep fires. It absorbs minor clock
// skew between the management server and the database and avoids
// firing a sweep right at the boundary where last_seen could still
// be one tick under the threshold.
cleanupWindow = 1 * time.Minute
// initialLoadMinDelay and initialLoadMaxDelay bracket the random
// delay applied before the post-restart catch-up query runs. Spread
// across replicas this prevents a thundering herd of catch-up
// queries hitting the database simultaneously after a deploy.
initialLoadMinDelay = 8 * time.Minute
initialLoadMaxDelay = 10 * time.Minute
)
var (
timeNow = time.Now
)
type ephemeralPeer struct {
id string
accountID string
deadline time.Time
next *ephemeralPeer
// accountEntry is the per-account state held by the cleanup tracker.
// We don't track which peers are pending — the sweep query gets the
// authoritative list straight from the database every time. We only
// need to know the latest disconnect we've observed for this account
// (so we can decide when it's safe to drop the entry) and the timer
// that will fire the next sweep.
type accountEntry struct {
lastDisconnectedAt time.Time
timer *time.Timer
}
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
// in worst case we will get invalid error message in this manager.
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
// EphemeralManager tracks accounts that may have ephemeral peers in
// need of cleanup and runs a per-account sweep at the appropriate
// time. State is in-memory and account-scoped: a sweep deletes any
// ephemeral peer in the account that has been disconnected for at
// least lifeTime, then either drops the account from the tracker
// (when no recent disconnects have arrived) or re-arms the timer.
type EphemeralManager struct {
store store.Store
peersManager peers.Manager
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
peersLock sync.Mutex
timer *time.Timer
accountsLock sync.Mutex
accounts map[string]*accountEntry
// initialLoadTimer is the one-shot timer used to defer the
// post-restart catch-up query; held so Stop() can cancel it.
initialLoadTimer *time.Timer
// stopped is flipped by Stop() so any timer that fires after
// teardown becomes a no-op instead of touching a half-dismantled
// store.
stopped bool
lifeTime time.Duration
cleanupWindow time.Duration
// initialLoadDelay returns the wall-clock delay to wait before
// running the post-restart catch-up query. Pluggable so tests can
// fire the load immediately.
initialLoadDelay func() time.Duration
// bgCtx is the long-lived context captured at LoadInitialPeers
// time. Timer-driven sweeps use it because they fire long after
// the original gRPC handler ctx that produced an OnPeerDisconnected
// call has been cancelled.
bgCtx context.Context
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
// no-op when the receiver is nil so deployments without an app
// metrics provider work unchanged.
metrics *telemetry.EphemeralPeersMetrics
}
// NewEphemeralManager instantiate new EphemeralManager
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
return &EphemeralManager{
store: store,
peersManager: peersManager,
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
store: store,
peersManager: peersManager,
accounts: make(map[string]*accountEntry),
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
initialLoadDelay: defaultInitialLoadDelay,
}
}
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
// SetMetrics attaches a metrics collector. Pass nil to detach.
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
e.accountsLock.Lock()
e.metrics = m
e.accountsLock.Unlock()
}
// LoadInitialPeers schedules the post-restart catch-up query for a
// random moment 8-10 minutes from now. Returns immediately. The
// catch-up populates the per-account tracker from the database so any
// peers that disconnected before the restart still get cleaned up.
//
// The random delay is critical: without it, every management replica
// hitting the same Postgres instance after a deploy would issue the
// catch-up query simultaneously.
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.loadEphemeralPeers(ctx)
if e.headPeer != nil {
e.timer = time.AfterFunc(e.lifeTime, func() {
e.cleanup(ctx)
})
}
}
// Stop timer
func (e *EphemeralManager) Stop() {
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.timer != nil {
e.timer.Stop()
}
}
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
e.bgCtx = ctx
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.removePeer(peer.ID)
// stop the unnecessary timer
if e.headPeer == nil && e.timer != nil {
e.timer.Stop()
e.timer = nil
}
delay := e.initialLoadDelay()
log.WithContext(ctx).Infof("ephemeral peer initial load scheduled in %s", delay)
e.initialLoadTimer = time.AfterFunc(delay, func() {
e.loadInitialAccounts(e.bgCtx)
})
}
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the EphemeralLifeTime period.
// Stop cancels the deferred initial load and any per-account timers.
func (e *EphemeralManager) Stop() {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
e.stopped = true
if e.initialLoadTimer != nil {
e.initialLoadTimer.Stop()
e.initialLoadTimer = nil
}
for _, entry := range e.accounts {
if entry.timer != nil {
entry.timer.Stop()
}
}
e.accounts = make(map[string]*accountEntry)
}
// OnPeerConnected is a no-op in the account-scoped design. The sweep
// query filters out connected peers at the database level, so we don't
// need an explicit "remove from list" signal when a peer reconnects.
// Kept on the interface to preserve the existing call sites.
func (e *EphemeralManager) OnPeerConnected(_ context.Context, _ *nbpeer.Peer) {
}
// OnPeerDisconnected registers a disconnect for the peer's account and
// arms a sweep if one isn't already scheduled. Non-ephemeral peers are
// ignored.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.isPeerOnList(peer.ID) {
return
}
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
if e.timer == nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
})
}
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := e.newDeadLine()
for _, p := range peers {
e.addPeer(p.AccountID, p.ID, t)
}
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
}
func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
e.peersLock.Lock()
now := timeNow()
for p := e.headPeer; p != nil; p = p.next {
if now.Before(p.deadline) {
break
}
deletePeers[p.id] = p
e.headPeer = p.next
if p.next == nil {
e.tailPeer = nil
}
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
if e.headPeer != nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
entry, existed := e.accounts[peer.AccountID]
if !existed {
entry = &accountEntry{}
e.accounts[peer.AccountID] = entry
e.metrics.IncPending()
}
entry.lastDisconnectedAt = now
if entry.timer == nil {
delay := e.lifeTime + e.cleanupWindow
log.WithContext(ctx).Tracef("ephemeral: scheduling sweep for account %s in %s", peer.AccountID, delay)
accountID := peer.AccountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(e.bgCtxOrFallback(ctx), accountID)
})
} else {
e.timer = nil
}
e.peersLock.Unlock()
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
}
}
}
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{
id: peerID,
accountID: accountID,
deadline: deadline,
// bgCtxOrFallback returns the long-lived background context captured at
// LoadInitialPeers time, falling back to the supplied ctx when the
// manager hasn't been started through LoadInitialPeers (e.g. in tests
// that drive the manager directly). Must be called with the lock held
// or before the timer is armed.
func (e *EphemeralManager) bgCtxOrFallback(ctx context.Context) context.Context {
if e.bgCtx != nil {
return e.bgCtx
}
if e.headPeer == nil {
e.headPeer = ep
}
if e.tailPeer != nil {
e.tailPeer.next = ep
}
e.tailPeer = ep
return ctx
}
func (e *EphemeralManager) removePeer(id string) {
if e.headPeer == nil {
// loadInitialAccounts runs the post-restart catch-up query and seeds
// the tracker with one entry per account that has at least one
// disconnected ephemeral peer.
func (e *EphemeralManager) loadInitialAccounts(ctx context.Context) {
accounts, err := e.store.GetEphemeralAccountsLastDisconnect(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to load ephemeral accounts on startup: %v", err)
return
}
if e.headPeer.id == id {
e.headPeer = e.headPeer.next
if e.tailPeer.id == id {
e.tailPeer = nil
now := timeNow()
added := 0
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
for accountID, lastDisc := range accounts {
// If we already learned about this account via an
// OnPeerDisconnected that arrived during the random delay
// window, prefer the live timestamp.
if _, alreadyTracked := e.accounts[accountID]; alreadyTracked {
continue
}
entry := &accountEntry{lastDisconnectedAt: lastDisc}
horizon := lastDisc.Add(e.lifeTime)
var delay time.Duration
if horizon.After(now) {
delay = horizon.Sub(now) + e.cleanupWindow
} else {
// Already past the staleness window — sweep right away
// (one cleanupWindow later, to keep startup load smooth
// when many accounts qualify at once).
delay = e.cleanupWindow
}
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
e.accounts[accountID] = entry
added++
}
e.metrics.AddPending(int64(added))
log.WithContext(ctx).Debugf("ephemeral: loaded %d account(s) for cleanup tracking", added)
}
// sweep runs the cleanup pass for a single account. It queries the
// database for disconnected ephemeral peers that have crossed the
// staleness window, deletes them via peers.Manager, and then decides
// whether to drop the account from the tracker or re-arm the timer.
func (e *EphemeralManager) sweep(ctx context.Context, accountID string) {
now := timeNow()
e.accountsLock.Lock()
entry, ok := e.accounts[accountID]
if !ok || e.stopped {
e.accountsLock.Unlock()
return
}
lastDisc := entry.lastDisconnectedAt
entry.timer = nil
e.accountsLock.Unlock()
threshold := now.Add(-e.lifeTime)
stalePeerIDs, err := e.store.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, threshold)
if err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to query stale peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
return
}
for p := e.headPeer; p.next != nil; p = p.next {
if p.next.id == id {
// if we remove the last element from the chain then set the last-1 as tail
if e.tailPeer.id == id {
e.tailPeer = p
}
p.next = p.next.next
if len(stalePeerIDs) > 0 {
log.WithContext(ctx).Tracef("ephemeral: deleting %d peer(s) for account %s", len(stalePeerIDs), accountID)
if err := e.peersManager.DeletePeers(ctx, accountID, stalePeerIDs, activity.SystemInitiator, true); err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to delete peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
return
}
e.metrics.CountCleanupRun()
e.metrics.CountPeersCleaned(int64(len(stalePeerIDs)))
}
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
entry, ok = e.accounts[accountID]
if !ok {
return
}
// Drop rule: if every disconnect we've observed has now crossed
// the staleness window, the sweep we just ran saw everything that
// could possibly need cleaning. Dropping is safe — a future
// disconnect will recreate the entry. The check uses the latest
// lastDisc, which may have advanced (concurrently with the sweep
// itself) due to a new OnPeerDisconnected, in which case we
// correctly re-arm.
horizon := entry.lastDisconnectedAt.Add(e.lifeTime)
if !horizon.After(now) {
delete(e.accounts, accountID)
e.metrics.DecPending(1)
log.WithContext(ctx).Tracef("ephemeral: dropping account %s (lastDisc=%s, horizon=%s, now=%s)",
accountID, lastDisc, horizon, now)
return
}
delay := horizon.Sub(now) + e.cleanupWindow
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
}
func (e *EphemeralManager) isPeerOnList(id string) bool {
for p := e.headPeer; p != nil; p = p.next {
if p.id == id {
return true
}
// rearm reschedules a sweep `delay` from now. Used after a recoverable
// error in the sweep path so the account doesn't get stuck.
func (e *EphemeralManager) rearm(ctx context.Context, accountID string, delay time.Duration) {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
return false
entry, ok := e.accounts[accountID]
if !ok {
return
}
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
}
func (e *EphemeralManager) newDeadLine() time.Time {
return timeNow().Add(e.lifeTime)
// defaultInitialLoadDelay returns a random duration in
// [initialLoadMinDelay, initialLoadMaxDelay). Process-wide
// math/rand is acceptable here — the delay is purely a smoothing
// jitter, not a security primitive.
func defaultInitialLoadDelay() time.Duration {
span := int64(initialLoadMaxDelay - initialLoadMinDelay)
if span <= 0 {
return initialLoadMinDelay
}
return initialLoadMinDelay + time.Duration(rand.Int63n(span))
}

View File

@@ -2,299 +2,544 @@ package manager
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// MockStore is a thin in-memory stand-in that implements only the two
// methods the EphemeralManager uses. It honors the account / ephemeral
// / connected / lastSeen attributes of each peer so the cleanup logic
// can be exercised end-to-end without bringing up sqlite or Postgres.
type MockStore struct {
store.Store
mu sync.Mutex
account *types.Account
}
func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
for _, v := range s.account.Peers {
if v.Ephemeral {
peers = append(peers, v)
func (s *MockStore) GetStaleEphemeralPeerIDsForAccount(_ context.Context, accountID string, olderThan time.Time) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.account == nil || s.account.Id != accountID {
return nil, nil
}
var ids []string
for _, p := range s.account.Peers {
if !p.Ephemeral {
continue
}
if p.Status == nil || p.Status.Connected {
continue
}
if p.Status.LastSeen.Before(olderThan) {
ids = append(ids, p.ID)
}
}
return peers, nil
return ids, nil
}
type MockAccountManager struct {
mu sync.Mutex
nbAccount.Manager
store *MockStore
deletePeerCalls int
bufferUpdateCalls map[string]int
wg *sync.WaitGroup
}
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
a.mu.Lock()
defer a.mu.Unlock()
a.deletePeerCalls++
delete(a.store.account.Peers, peerID)
if a.wg != nil {
a.wg.Done()
func (s *MockStore) GetEphemeralAccountsLastDisconnect(_ context.Context) (map[string]time.Time, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := map[string]time.Time{}
if s.account == nil {
return out, nil
}
return nil
}
func (a *MockAccountManager) GetDeletePeerCalls() int {
a.mu.Lock()
defer a.mu.Unlock()
return a.deletePeerCalls
}
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
a.bufferUpdateCalls = make(map[string]int)
var latest time.Time
hasAny := false
for _, p := range s.account.Peers {
if !p.Ephemeral || p.Status == nil || p.Status.Connected {
continue
}
if !hasAny || p.Status.LastSeen.After(latest) {
latest = p.Status.LastSeen
hasAny = true
}
}
a.bufferUpdateCalls[accountID]++
}
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
a.mu.Lock()
defer a.mu.Unlock()
if a.bufferUpdateCalls == nil {
return 0
if hasAny {
out[s.account.Id] = latest
}
return a.bufferUpdateCalls[accountID]
return out, nil
}
func (a *MockAccountManager) GetStore() store.Store {
return a.store
}
func TestNewManager(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
// withFakeClock pins timeNow to a settable value for the duration of t.
// Returns a getter and a setter so subtests can advance virtual time.
func withFakeClock(t *testing.T, start time.Time) (get func() time.Time, set func(time.Time)) {
t.Helper()
var mu sync.Mutex
now := start
timeNow = func() time.Time {
return startTime
mu.Lock()
defer mu.Unlock()
return now
}
t.Cleanup(func() { timeNow = time.Now })
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
}
return func() time.Time {
mu.Lock()
defer mu.Unlock()
return now
}, func(v time.Time) {
mu.Lock()
defer mu.Unlock()
now = v
}
}
func TestNewManagerPeerConnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
// newManagerForTest builds a manager with short timers and no random
// initial-load delay so tests run instantly.
func newManagerForTest(t *testing.T, st store.Store, peersMgr peers.Manager) *EphemeralManager {
t.Helper()
mgr := NewEphemeralManager(st, peersMgr)
mgr.lifeTime = 100 * time.Millisecond
mgr.cleanupWindow = 10 * time.Millisecond
mgr.initialLoadDelay = func() time.Duration { return 0 }
t.Cleanup(mgr.Stop)
return mgr
}
func TestNewManagerPeerDisconnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
// TestOnPeerDisconnected_RegistersAndSweeps drives the OnPeerDisconnected
// path with a fake clock: a single ephemeral peer disconnects, we
// advance past the staleness window, and the sweep deletes it.
func TestOnPeerDisconnected_RegistersAndSweeps(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for the one disconnected peer
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
}
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
const (
ephemeralPeers = 10
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
mockStore := &MockStore{}
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
wg := &sync.WaitGroup{}
wg.Add(ephemeralPeers)
mockAM := &MockAccountManager{
store: mockStore,
wg: wg,
}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
peersMgr := peers.NewMockManager(ctrl)
// Set up expectation that DeletePeers will be called once with all peer IDs
peersManager.EXPECT().
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
// Simulate the actual deletion behavior
for _, peerID := range peerIDs {
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
return err
}
var deletedMu sync.Mutex
var deleted []string
var deleteCalls atomic.Int32
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, accountID string, peerIDs []string, _ string, _ bool) error {
deleteCalls.Add(1)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
mockStore.mu.Unlock()
deletedMu.Lock()
deleted = append(deleted, peerIDs...)
deletedMu.Unlock()
return nil
}).
Times(1)
}).AnyTimes()
mgr := NewEphemeralManager(mockStore, peersManager)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
mgr := newManagerForTest(t, mockStore, peersMgr)
// Add peers and disconnect them at slightly different times (within cleanup window)
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
// One ephemeral peer that disconnected "now".
now := getNow()
p := &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
// Advance past lifeTime + cleanupWindow so the timer-driven sweep fires.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool { return deleteCalls.Load() >= 1 }, 2*time.Second, 5*time.Millisecond,
"sweep should fire and delete the stale peer")
deletedMu.Lock()
deletedCopy := append([]string(nil), deleted...)
deletedMu.Unlock()
require.Equal(t, []string{"p1"}, deletedCopy, "only the one ephemeral peer should be deleted")
}
// TestOnPeerDisconnected_NonEphemeralIgnored: a non-ephemeral disconnect
// must not register the account or arm any timer.
func TestOnPeerDisconnected_NonEphemeralIgnored(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
// No DeletePeers expectation — must not be called.
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: false,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "non-ephemeral disconnect must not register an account")
mgr.accountsLock.Unlock()
}
// TestSweep_DropsAccountWhenIdle: after a sweep cleans the stale peers,
// if no more disconnects have arrived the account must be dropped from
// the in-memory tracker.
func TestSweep_DropsAccountWhenIdle(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should be dropped after sweep with no new disconnects")
}
// TestSweep_ReArmsWhenNewDisconnectArrived: simulate the race where a
// fresh disconnect arrives just before the sweep fires. The sweep must
// observe the updated lastDisc and re-arm rather than drop.
func TestSweep_ReArmsWhenNewDisconnectArrived(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
p1 := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p1.ID] = p1
mgr.OnPeerDisconnected(context.Background(), p1)
// Advance most of the way toward the first sweep, then introduce
// a fresh disconnect that resets lastDisc.
setNow(now.Add(mgr.lifeTime - 10*time.Millisecond))
p2 := &nbpeer.Peer{ID: "p2", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
mockStore.account.Peers[p2.ID] = p2
mgr.OnPeerDisconnected(context.Background(), p2)
// Push past p1's staleness so the first sweep runs and cleans p1
// but observes p2 already on the account entry. It must re-arm.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p1"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p1 should be cleaned at the first sweep")
// The account should still be tracked because p2 is younger than lifeTime
// from the sweep's vantage point at this moment.
mgr.accountsLock.Lock()
_, stillTracked := mgr.accounts["acc-1"]
mgr.accountsLock.Unlock()
require.True(t, stillTracked, "account should remain tracked because p2's disconnect kept it active")
// Push past p2's staleness; second sweep cleans p2 and drops the account.
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should drop after the final sweep")
}
// TestSweep_BatchesPeersPerAccount: many ephemeral peers disconnect on
// the same account; a single sweep must delete them all in one
// DeletePeers call.
func TestSweep_BatchesPeersPerAccount(t *testing.T) {
const ephemeralPeers = 8
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
deleteBatches := make(chan []string, 4)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
cp := append([]string(nil), peerIDs...)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
deleteBatches <- cp
return nil
}).Times(1)
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
for i := 0; i < ephemeralPeers; i++ {
id := fmt.Sprintf("p-%d", i)
// Stagger by a fraction of cleanupWindow so they all fall on
// the same sweep tick.
when := now.Add(time.Duration(i) * time.Millisecond)
p := &nbpeer.Peer{ID: id, AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: when}}
mockStore.account.Peers[id] = p
mgr.OnPeerDisconnected(context.Background(), p)
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
}
// Advance time past the lifetime to trigger cleanup
startTime = startTime.Add(testLifeTime + testCleanupWindow)
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
// Wait for all deletions to complete
wg.Wait()
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
select {
case batch := <-deleteBatches:
require.Len(t, batch, ephemeralPeers, "all peers should be deleted in a single batch")
case <-time.After(2 * time.Second):
t.Fatal("expected one batched DeletePeers call")
}
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
// TestLoadInitialAccounts_SeedsFromStore exercises the post-restart
// catch-up path: pre-populate the store, point the manager at it, and
// confirm both already-stale and not-yet-stale peers get cleaned at
// their proper times.
func TestLoadInitialAccounts_SeedsFromStore(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: false,
}
store.account.Peers[p.ID] = p
now := getNow()
// p-stale: already past the staleness window when load runs.
mockStore.account.Peers["p-stale"] = &nbpeer.Peer{
ID: "p-stale", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now.Add(-time.Hour)},
}
// p-fresh: disconnected but not yet stale.
mockStore.account.Peers["p-fresh"] = &nbpeer.Peer{
ID: "p-fresh", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
}
for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: true,
}
store.account.Peers[p.ID] = p
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
// Drive loadInitialAccounts directly with the fake-clock-aware now.
mgr.loadInitialAccounts(context.Background())
// First sweep should fire shortly (cleanupWindow) for the stale peer.
setNow(now.Add(5 * mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p-stale"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p-stale should be deleted on the first sweep")
// p-fresh is not yet stale; advance past its window.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p-fresh"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p-fresh should be deleted once it crosses the staleness window")
}
// TestStop_CancelsPendingWork verifies that Stop() cancels both the
// deferred initial load and per-account sweep timers and that
// subsequent OnPeerDisconnected calls are ignored.
func TestStop_CancelsPendingWork(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
// DeletePeers must NOT be called after Stop.
mgr := NewEphemeralManager(mockStore, peersMgr)
mgr.lifeTime = 100 * time.Millisecond
mgr.cleanupWindow = 10 * time.Millisecond
// Use a long delay so the initial-load timer is still pending.
mgr.initialLoadDelay = func() time.Duration { return time.Hour }
mgr.LoadInitialPeers(context.Background())
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.NotNil(t, mgr.initialLoadTimer, "initial-load timer should be armed")
require.Len(t, mgr.accounts, 1, "account should be tracked after disconnect")
mgr.accountsLock.Unlock()
mgr.Stop()
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "Stop should clear tracked accounts")
require.True(t, mgr.stopped, "stopped flag must be set")
mgr.accountsLock.Unlock()
// Post-stop disconnect must be ignored.
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p2", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "disconnects after Stop must be ignored")
mgr.accountsLock.Unlock()
}
// TestOnPeerConnected_IsNoop: the OnPeerConnected hook is preserved on
// the interface but does nothing in the per-account model — the sweep
// query filters connected peers at the DB level.
func TestOnPeerConnected_IsNoop(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
})
mgr.accountsLock.Lock()
require.Len(t, mgr.accounts, 1, "disconnect should track the account")
mgr.accountsLock.Unlock()
mgr.OnPeerConnected(context.Background(), &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true})
mgr.accountsLock.Lock()
require.Len(t, mgr.accounts, 1, "OnPeerConnected must be a no-op")
mgr.accountsLock.Unlock()
}
// TestSweep_StoreErrorReArms: if the stale-peer query fails, the
// account must remain tracked and a follow-up sweep gets scheduled.
func TestSweep_StoreErrorReArms(t *testing.T) {
mockStore := &erroringStore{
MockStore: MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)},
}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
mgr := newManagerForTest(t, mockStore, peersMgr)
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
mockStore.fail.Store(true)
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
// Wait until the failing sweep has run at least once.
require.Eventually(t, func() bool { return mockStore.failedCalls.Load() >= 1 },
2*time.Second, 5*time.Millisecond, "expected at least one failing sweep")
mgr.accountsLock.Lock()
_, stillTracked := mgr.accounts["acc-1"]
mgr.accountsLock.Unlock()
require.True(t, stillTracked, "account must remain tracked after a sweep error")
// Recover and ensure the rearmed sweep cleans up.
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mockStore.fail.Store(false)
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p1"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "rearmed sweep should clean up after the store recovers")
}
// erroringStore is a MockStore that can be flipped into a failing mode
// to exercise the sweep's error-rearm path.
type erroringStore struct {
MockStore
fail atomic.Bool
failedCalls atomic.Int32
}
func (s *erroringStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
if s.fail.Load() {
s.failedCalls.Add(1)
return nil, errors.New("synthetic store error")
}
return s.MockStore.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
}
// TestDefaultInitialLoadDelay confirms the jitter falls inside the
// documented [8m, 10m) range — sanity check for the production timer.
func TestDefaultInitialLoadDelay(t *testing.T) {
for i := 0; i < 1000; i++ {
d := defaultInitialLoadDelay()
assert.GreaterOrEqual(t, d, initialLoadMinDelay)
assert.Less(t, d, initialLoadMaxDelay)
}
}
@@ -351,3 +596,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
}
return acc
}
// silence the import "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
// (still needed indirectly for ephemeral.EphemeralLifeTime in production paths).
var _ = ephemeral.EphemeralLifeTime

View File

@@ -31,6 +31,7 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
@@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
@@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// Verify the target cluster is in the available clusters for this account
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string {
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -298,6 +299,34 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
}
publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("get public cluster addresses: %w", err)
}
seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
for _, addr := range byopAddresses {
if _, ok := seen[addr]; ok {
continue
}
seen[addr] = struct{}{}
merged = append(merged, addr)
}
for _, addr := range publicAddresses {
if _, ok := seen[addr]; ok {
continue
}
seen[addr] = struct{}{}
merged = append(merged, addr)
}
return merged, nil
}
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := ""
bestLen := -1

View File

@@ -0,0 +1,154 @@
package manager
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveClusterAddressesFunc != nil {
return m.getActiveClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveClusterAddressesForAccountFunc != nil {
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"shared.example.com", "byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "BYOP cluster addresses")
}
func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return nil, errors.New("db error")
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "public cluster addresses")
}
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}

View File

@@ -11,15 +11,19 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error)
Disconnect(ctx context.Context, proxyID, sessionID string) error
Heartbeat(ctx context.Context, p *Proxy) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.

View File

@@ -16,11 +16,16 @@ type store interface {
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// Manager handles all proxy operations
@@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
@@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
SessionID: sessionID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
AccountID: accountID,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
Status: proxy.StatusConnected,
Capabilities: caps,
}
@@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
}
// Disconnect marks a proxy as disconnected in the database.
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
return err
@@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro
}
// Heartbeat updates the proxy's last seen timestamp.
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
return err
@@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
@@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
return addresses, nil
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
clusters, err := m.store.GetActiveProxyClusters(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
return nil, err
}
return clusters, nil
}
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
@@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
return nil, err
}
return addresses, nil
}
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
return m.store.GetProxyByAccountID(ctx, accountID)
}
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
return m.store.CountProxiesByAccountID(ctx, accountID)
}
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
if err != nil {
return false, err
}
return !conflicting, nil
}
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
return err
}
return nil
}

View File

@@ -0,0 +1,337 @@
package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
if m.saveProxyFunc != nil {
return m.saveProxyFunc(ctx, p)
}
return nil
}
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
if m.disconnectProxyFunc != nil {
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
}
return nil
}
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
if m.updateProxyHeartbeatFunc != nil {
return m.updateProxyHeartbeatFunc(ctx, p)
}
return nil
}
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveProxyClusterAddressesFunc != nil {
return m.getActiveProxyClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveProxyClusterAddressesForAccFunc != nil {
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
if m.cleanupStaleProxiesFunc != nil {
return m.cleanupStaleProxiesFunc(ctx, d)
}
return nil
}
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
if m.getProxyByAccountIDFunc != nil {
return m.getProxyByAccountIDFunc(ctx, accountID)
}
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
if m.countProxiesByAccountIDFunc != nil {
return m.countProxiesByAccountIDFunc(ctx, accountID)
}
return 0, nil
}
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
if m.isClusterAddressConflictingFunc != nil {
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
}
return false, nil
}
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if m.deleteAccountClusterFunc != nil {
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
}
return nil
}
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")
m, err := NewManager(s, meter)
if err != nil {
panic(err)
}
return m
}
func TestConnect_WithAccountID(t *testing.T) {
accountID := "acc-123"
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Equal(t, "proxy-1", savedProxy.ID)
assert.Equal(t, "session-1", savedProxy.SessionID)
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
assert.Equal(t, &accountID, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
assert.NotNil(t, savedProxy.ConnectedAt)
}
func TestConnect_WithoutAccountID(t *testing.T) {
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Nil(t, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
}
func TestConnect_StoreError(t *testing.T) {
s := &mockStore{
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
assert.Error(t, err)
}
func TestIsClusterAddressAvailable(t *testing.T) {
tests := []struct {
name string
conflicting bool
storeErr error
wantResult bool
wantErr bool
}{
{
name: "available - no conflict",
conflicting: false,
wantResult: true,
},
{
name: "not available - conflict exists",
conflicting: true,
wantResult: false,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
return tt.conflicting, tt.storeErr
},
}
mgr := newTestManager(s)
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
})
}
}
func TestCountAccountProxies(t *testing.T) {
tests := []struct {
name string
count int64
storeErr error
wantCount int64
wantErr bool
}{
{
name: "no proxies",
count: 0,
wantCount: 0,
},
{
name: "one proxy",
count: 1,
wantCount: 1,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
return tt.count, tt.storeErr
},
}
mgr := newTestManager(s)
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantCount, count)
})
}
}
func TestGetAccountProxy(t *testing.T) {
accountID := "acc-123"
t.Run("found", func(t *testing.T) {
expected := &proxy.Proxy{
ID: "proxy-1",
ClusterAddress: "byop.example.com",
AccountID: &accountID,
Status: proxy.StatusConnected,
}
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
assert.Equal(t, accountID, accID)
return expected, nil
},
}
mgr := newTestManager(s)
p, err := mgr.GetAccountProxy(context.Background(), accountID)
require.NoError(t, err)
assert.Equal(t, expected, p)
})
t.Run("not found", func(t *testing.T) {
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, errors.New("not found")
},
}
mgr := newTestManager(s)
_, err := mgr.GetAccountProxy(context.Background(), accountID)
assert.Error(t, err)
})
}
func TestDeleteAccountCluster(t *testing.T) {
t.Run("success", func(t *testing.T) {
var deletedCluster, deletedAccount string
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
deletedCluster = clusterAddress
deletedAccount = accountID
return nil
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
require.NoError(t, err)
assert.Equal(t, "cluster.example.com", deletedCluster)
assert.Equal(t, "acc-123", deletedAccount)
})
t.Run("store error", func(t *testing.T) {
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
assert.Error(t, err)
})
}
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
expected := []string{"byop.example.com"}
s := &mockStore{
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return expected, nil
},
}
mgr := newTestManager(s)
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, expected, result)
}

View File

@@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
}
// Disconnect mocks base method.
@@ -136,19 +136,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
ret0, _ := ret[0].([]Cluster)
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
}
// Heartbeat mocks base method.
@@ -165,6 +163,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
}
// GetAccountProxy mocks base method.
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountProxy indicates an expected call of GetAccountProxy.
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
}
// CountAccountProxies mocks base method.
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAccountProxies indicates an expected call of CountAccountProxies.
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
}
// IsClusterAddressAvailable mocks base method.
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller

View File

@@ -1,6 +1,13 @@
package proxy
import "time"
import (
"time"
)
const (
StatusConnected = "connected"
StatusDisconnected = "disconnected"
)
// Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability.
@@ -21,6 +28,7 @@ type Proxy struct {
SessionID string `gorm:"type:varchar(36)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
@@ -36,6 +44,8 @@ func (Proxy) TableName() string {
// Cluster represents a group of proxy nodes serving the same address.
type Cluster struct {
ID string
Address string
ConnectedProxies int
SelfHosted bool
}

View File

@@ -0,0 +1,195 @@
package proxytoken
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
store store.Store
permissionsManager permissions.Manager
}
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{store: s, permissionsManager: permissionsManager}
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
}
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
var req api.ProxyTokenRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" || len(req.Name) > 255 {
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
return
}
var expiresIn time.Duration
if req.ExpiresIn != nil {
if *req.ExpiresIn < 0 {
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
return
}
if *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
}
}
accountID := userAuth.AccountId
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
if err != nil {
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
return
}
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
}
resp := make([]api.ProxyToken, 0, len(tokens))
for _, token := range tokens {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokenID := mux.Vars(r)["tokenId"]
if tokenID == "" {
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
return
}
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
} else {
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
}
return
}
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
return
}
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
resp := api.ProxyToken{
Id: token.ID,
Name: token.Name,
Revoked: token.Revoked,
}
if !token.CreatedAt.IsZero() {
resp.CreatedAt = token.CreatedAt
}
if token.ExpiresAt != nil {
resp.ExpiresAt = token.ExpiresAt
}
if token.LastUsed != nil {
resp.LastUsed = token.LastUsed
}
return resp
}
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
base := toProxyTokenResponse(&generated.ProxyAccessToken)
plainToken := string(generated.PlainToken)
return api.ProxyTokenCreated{
Id: base.Id,
Name: base.Name,
CreatedAt: base.CreatedAt,
ExpiresAt: base.ExpiresAt,
LastUsed: base.LastUsed,
Revoked: base.Revoked,
PlainToken: plainToken,
}
}

View File

@@ -0,0 +1,275 @@
package proxytoken
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func authContext(accountID, userID string) context.Context {
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
AccountId: accountID,
UserId: userID,
})
}
func TestCreateToken_AccountScoped(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "my-token"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp api.ProxyTokenCreated
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
assert.NotEmpty(t, resp.PlainToken)
assert.Equal(t, "my-token", resp.Name)
assert.False(t, resp.Revoked)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.AccountID)
assert.Equal(t, accountID, *savedToken.AccountID)
assert.Equal(t, "user-1", savedToken.CreatedBy)
}
func TestCreateToken_WithExpiration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "expiring-token", "expires_in": 3600}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.ExpiresAt)
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
}
func TestCreateToken_EmptyName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": ""}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestCreateToken_PermissionDenied(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": "test"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}
func TestListTokens(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
now := time.Now()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.listTokens(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp []api.ProxyToken
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
require.Len(t, resp, 2)
assert.Equal(t, "tok-1", resp[0].Id)
assert.False(t, resp[0].Revoked)
assert.Equal(t, "tok-2", resp[1].Id)
assert.True(t, resp[1].Revoked)
}
func TestRevokeToken_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
Name: "test-token",
AccountID: &accountID,
}, nil)
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext(accountID, "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRevokeToken_WrongAccount(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
otherAccount := "acc-other"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: &otherAccount,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestRevokeToken_ManagementWideToken(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: nil,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}

View File

@@ -10,6 +10,7 @@ import (
type Manager interface {
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
@@ -28,4 +29,5 @@ type Manager interface {
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
}

View File

@@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()

View File

@@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
@@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
SelfHosted: c.SelfHosted,
})
}
util.WriteJSONObject(r.Context(), w, apiClusters)
}
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusterAddress := mux.Vars(r)["clusterAddress"]
if clusterAddress == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
return
}
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -122,7 +122,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
return m.store.GetActiveProxyClusters(ctx, accountID)
}
// DeleteAccountCluster removes all proxy registrations for the given cluster address
// owned by the account.
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
@@ -292,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
return err
}
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
@@ -307,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
@@ -421,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
return err
}
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
@@ -434,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
@@ -538,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
svcForCaps.ProxyCluster = effectiveCluster
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
return nil, err
}
// Validate subdomain requirement *before* the transaction: the underlying
// capability lookup talks to the main DB pool, and SQLite's single-connection
// pool would self-deadlock if this ran while the tx already held the only
// connection.
if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
return nil, err
}
var updateInfo serviceUpdateInfo
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
})
return &updateInfo, err
@@ -571,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
return existing.ProxyCluster, nil
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
@@ -589,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
return err
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
@@ -614,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -624,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
return nil
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
// handleDomainChange validates the new domain is free inside the transaction
// and applies the pre-resolved cluster (computed outside the tx by
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
// because the transaction already holds the only connection.
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
svc.ProxyCluster = newCluster
}
if effectiveCluster != "" {
svc.ProxyCluster = effectiveCluster
}
return nil
}
@@ -986,6 +1003,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
return services, nil
}
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv
}
@@ -714,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
}
// HTTP/HTTPS: build full URL
hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]")
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Host: bracketIPv6Host(hostNoBrackets),
Path: "/",
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10))
}
path := "/"
@@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
return pathMappings
}
// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as
// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6
// addresses are bracketed too since their textual form contains colons.
func bracketIPv6Host(host string) string {
if strings.HasPrefix(host, "[") {
return host
}
if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() {
return "[" + host + "]"
}
return host
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:

View File

@@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
{
name: "domain host without port is unchanged",
protocol: "http",
host: "example.com",
port: 0,
wantTarget: "http://example.com/",
},
{
name: "domain host with non-default port is unchanged",
protocol: "http",
host: "example.com",
port: 8080,
wantTarget: "http://example.com:8080/",
},
{
name: "ipv6 host without port is bracketed",
protocol: "http",
host: "fb00:cafe:1::3",
port: 0,
wantTarget: "http://[fb00:cafe:1::3]/",
},
{
name: "ipv6 host with default port omits port and brackets host",
protocol: "http",
host: "fb00:cafe:1::3",
port: 80,
wantTarget: "http://[fb00:cafe:1::3]/",
},
{
name: "ipv6 host with non-default port is bracketed",
protocol: "http",
host: "fb00:cafe:1::3",
port: 8080,
wantTarget: "http://[fb00:cafe:1::3]:8080/",
},
{
name: "ipv6 loopback without port is bracketed",
protocol: "http",
host: "::1",
port: 0,
wantTarget: "http://[::1]/",
},
{
name: "ipv6 host with 5-digit port is bracketed",
protocol: "http",
host: "fb00:cafe::1",
port: 18080,
wantTarget: "http://[fb00:cafe::1]:18080/",
},
{
name: "pre-bracketed ipv6 without port stays single-bracketed",
protocol: "http",
host: "[fb00:cafe::1]",
port: 0,
wantTarget: "http://[fb00:cafe::1]/",
},
{
name: "pre-bracketed ipv6 with port is not double-bracketed",
protocol: "http",
host: "[fb00:cafe::1]",
port: 8080,
wantTarget: "http://[fb00:cafe::1]:8080/",
},
{
name: "v4-mapped ipv6 host without port is bracketed",
protocol: "http",
host: "::ffff:10.0.0.1",
port: 0,
wantTarget: "http://[::ffff:10.0.0.1]/",
},
{
name: "full-form 8-group ipv6 without port is bracketed",
protocol: "http",
host: "fb00:cafe:1:0:0:0:0:3",
port: 0,
wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -112,7 +112,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
if metrics := s.Metrics(); metrics != nil {
em.SetMetrics(metrics.EphemeralPeersMetrics())
}
return em
})
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -113,30 +114,47 @@ func (s *BaseServer) AccountManager() account.Manager {
})
}
func isMFAEnabledForAccount(accounts []*types.Account) bool {
if len(accounts) != 1 {
return false
}
settings := accounts[0].Settings
return settings != nil && settings.LocalMfaEnabled
}
func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
if embeddedEnabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
embeddedMgr, err := idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP service: %v", err)
}
return idpManager
if val := isMFAEnabledForAccount(s.Store().GetAllAccounts(context.Background())); val {
if err := embeddedMgr.SetMFAEnabled(context.Background(), val); err != nil {
log.Errorf("failed to set MFA enabled on embedded IDP: %v", err)
}
}
return embeddedMgr
}
// Fall back to external IdP service
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
idpManager, err := idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP service: %v", err)
}
return idpManager
}
return idpManager
return nil
})
}

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
@@ -50,6 +51,11 @@ type ProxyOIDCConfig struct {
KeysLocation string
}
// ProxyTokenChecker checks whether a proxy access token is still valid.
type ProxyTokenChecker interface {
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
}
// ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer
@@ -78,6 +84,9 @@ type ProxyServiceServer struct {
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
// Checker for proxy access token validity
tokenChecker ProxyTokenChecker
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig
@@ -123,6 +132,8 @@ type proxyConnection struct {
proxyID string
sessionID string
address string
accountID *string
tokenID string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
@@ -130,8 +141,19 @@ type proxyConnection struct {
cancel context.CancelFunc
}
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
token := GetProxyTokenFromContext(ctx)
if token == nil || token.AccountID == nil {
return nil
}
if requestAccountID == "" || *token.AccountID != requestAccountID {
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
}
return nil
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -141,6 +163,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel,
}
@@ -200,6 +223,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
var accountID *string
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
@@ -217,6 +259,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
@@ -224,7 +268,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel,
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
@@ -233,10 +276,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
cancel()
if accountID != nil {
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
@@ -266,6 +312,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
@@ -286,7 +333,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
go s.heartbeat(connCtx, proxyRecord)
go s.heartbeat(connCtx, conn, proxyRecord)
select {
case err := <-errChan:
@@ -298,8 +345,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
// heartbeat updates the proxy's last_seen timestamp every minute and
// disconnects the proxy if its access token has been revoked.
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -309,6 +357,19 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
}
if conn.tokenID != "" && s.tokenChecker != nil {
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
if err != nil {
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
continue
}
if !valid {
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
conn.cancel()
return
}
}
case <-ctx.Done():
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
return
@@ -316,8 +377,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only entries matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
@@ -335,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if end > len(mappings) {
end = len(mappings)
}
for _, m := range mappings[i:end] {
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
if err != nil {
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
}
m.AuthToken = token
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
@@ -355,23 +421,25 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
}
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
var services []*rpservice.Service
var err error
if conn.accountID != nil {
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
} else {
services, err = s.serviceManager.GetGlobalServices(ctx)
}
if err != nil {
return nil, fmt.Errorf("get services from store: %w", err)
}
oidcCfg := s.GetOIDCValidationConfig()
var mappings []*proto.ProxyMapping
for _, service := range services {
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
if err != nil {
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
if !proxyAcceptsMapping(conn, m) {
continue
}
@@ -380,8 +448,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
return mappings, nil
}
// isProxyAddressValid validates a proxy address
// isProxyAddressValid validates a proxy address (domain name or IP address)
func isProxyAddressValid(addr string) bool {
if addr == "" {
return false
}
if net.ParseIP(addr) != nil {
return true
}
_, err := domain.ValidateDomains([]string{addr})
return err == nil
}
@@ -405,6 +479,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
return nil, err
}
fields := log.Fields{
"service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
@@ -442,11 +520,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
// BYOP proxies only receive updates for their own account's services.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers")
updateAccountIDs := make(map[string]struct{})
for _, m := range update.Mapping {
if m.AccountId != "" {
updateAccountIDs[m.AccountId] = struct{}{}
}
}
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID)
connUpdate := update
if conn.accountID != nil && len(updateAccountIDs) > 0 {
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
return true
}
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
if len(filtered) == 0 {
return true
}
connUpdate = &proto.GetMappingUpdateResponse{
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
}
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
@@ -463,6 +562,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
})
}
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
conn.cancel()
s.connectedProxies.Delete(proxyID)
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
}
}
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
var filtered []*proto.ProxyMapping
for _, m := range mappings {
if m.AccountId == accountID {
filtered = append(filtered, m)
}
}
return filtered
}
// GetConnectedProxies returns a list of connected proxy IDs
func (s *ProxyServiceServer) GetConnectedProxies() []string {
var proxies []string
@@ -531,6 +650,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
continue
}
conn := connVal.(*proxyConnection)
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
continue
}
if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue
@@ -618,6 +740,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
@@ -737,6 +863,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
accountID := req.GetAccountId()
serviceID := req.GetServiceId()
protoStatus := req.GetStatus()
@@ -807,6 +937,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
// CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
serviceID := req.GetServiceId()
accountID := req.GetAccountId()
token := req.GetToken()
@@ -861,6 +995,10 @@ func strPtr(s string) *string {
}
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
redirectURL, err := url.Parse(req.GetRedirectUrl())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
@@ -989,21 +1127,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.serviceManager.GetGlobalServices(ctx)
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
}
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
break
}
}
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
}
if service.SessionPrivateKey == "" {
@@ -1101,6 +1227,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
@@ -1184,18 +1314,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("service not found for domain: %s", domain)
return s.serviceManager.GetServiceByDomain(ctx, domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {

View File

@@ -0,0 +1,29 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsProxyAddressValid(t *testing.T) {
tests := []struct {
name string
addr string
valid bool
}{
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
{name: "valid IPv6", addr: "::1", valid: true},
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
{name: "empty string", addr: "", valid: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
})
}
}

View File

@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}

View File

@@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID,
return nil
}
func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
@@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
if m.err != nil {
return nil, m.err
}
for _, services := range m.proxiesByAccount {
for _, svc := range services {
if svc.Domain == domain {
return svc, nil
}
}
}
return nil, errors.New("service not found for domain: " + domain)
}
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
assert.Empty(t, stream.messages[0].Mapping)
assert.True(t, stream.messages[0].InitialSyncComplete)
}
type hookingStream struct {
grpc.ServerStream
onSend func(*proto.GetMappingUpdateResponse)
}
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
if s.onSend != nil {
s.onSend(m)
}
return nil
}
func (s *hookingStream) Context() context.Context { return context.Background() }
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
func (s *hookingStream) SetTrailer(metadata.MD) {}
func (s *hookingStream) SendMsg(any) error { return nil }
func (s *hookingStream) RecvMsg(any) error { return nil }
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 6
const ttl = 100 * time.Millisecond
const sendDelay = 200 * time.Millisecond
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
s.tokenTTL = ttl
var validateErrs []error
stream := &hookingStream{
onSend: func(resp *proto.GetMappingUpdateResponse) {
for _, m := range resp.Mapping {
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
}
}
time.Sleep(sendDelay)
},
}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Empty(t, validateErrs,
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
}

View File

@@ -12,9 +12,12 @@ import (
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -316,6 +319,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
assert.Contains(t, err.Error(), "invalid state format")
}
func scopedCtx(accountID string) context.Context {
token := &types.ProxyAccessToken{
ID: "token-1",
AccountID: &accountID,
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func globalCtx() context.Context {
token := &types.ProxyAccessToken{
ID: "token-global",
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
assert.NoError(t, err)
}
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
err := enforceAccountScope(globalCtx(), "acc-1")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "acc-2")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "")
assert.NoError(t, err)
}
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
err := enforceAccountScope(context.Background(), "acc-1")
assert.NoError(t, err)
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))

View File

@@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
uncanceledCTX := context.WithoutCancel(ctx)
unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key)
defer unlock()
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime)
}
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
@@ -318,21 +318,33 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
return nil
}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
return nil
}
@@ -340,6 +352,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
@@ -348,6 +364,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
return nil
}
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}

Some files were not shown because too many files have changed in this diff Show More