Compare commits

...

47 Commits

Author SHA1 Message Date
mlsmaycon
5ebd39ad33 Log SQL connection pool stats periodically to monitor saturation and improve debugging. 2026-04-25 09:54:58 +02:00
mlsmaycon
69c0b96d73 Refactor fast-path Sync to log skip reasons, streamline tryFastPathSync outputs, and improve debug consistency. 2026-04-24 21:25:32 +02:00
mlsmaycon
d3ea28734c Introduce network serial caching in sync fast path, optimize DB reads, and add granular cache invalidation 2026-04-24 20:50:47 +02:00
mlsmaycon
4dddafc5a1 Add caching for ExtraSettings and peer groups in fast path to reduce DB reads. 2026-04-24 19:19:58 +02:00
mlsmaycon
8c521a7cb5 Refactor sync fast path to introduce caching for ExtraSettings and peer groups, optimize MarkPeerConnected with async writes, and reduce DB round trips. 2026-04-24 18:13:37 +02:00
mlsmaycon
ac6b73005d Upgrade cache logic in sync fast path to handle legacy entries and avoid corrupting HasUser flag. 2026-04-24 17:35:33 +02:00
mlsmaycon
cf7081e592 Refactor peer cache logic in sync fast path; consolidate and optimize write operations 2026-04-24 13:33:15 +02:00
mlsmaycon
94730fe066 Add debug log for cache hit in sync fast path 2026-04-24 12:00:32 +02:00
mlsmaycon
7e9d3485d8 [management] Cache peer snapshot + consolidate auth reads on Sync hot path
Trim the fast-path Sync handler by removing two DB round trips on cache hit:

1. Consolidate GetUserIDByPeerKey + GetAccountIDByPeerPubKey into a single
   GetPeerAuthInfoByPubKey store call. Both looked up the same peer row by
   pubkey and returned one column each; the new method SELECTs both columns
   in one query. AccountManager exposes it as GetPeerAuthInfo.

2. Extend peerSyncEntry with AccountID, PeerID, PeerKey, Ephemeral and a
   HasUser flag so the cache carries everything the fast path needs. On
   cache hit with a matching metaHash:

    - The Sync handler skips GetPeerAuthInfo entirely (entry.AccountID and
      entry.HasUser drive the loginFilter gate).
    - commitFastPath skips GetPeerByPeerPubKey by using the cached peer
      snapshot for OnPeerConnectedWithPeer.

Old cache entries from pre-step-2 shape still decode (missing fields zero
out) but IsComplete() returns false, so they fall through to the slow path
and get rewritten with the full shape on first pass. No migration needed.

Expected impact on a 16.8 s pathological Sync observed in production:
~6 s saved from eliminating one auth-read round trip, the pre-fast-path
GetPeerAuthInfo on cache hit, and GetPeerByPeerPubKey in commitFastPath.
Cache miss / cold start remain on the slow path unchanged.

Account-serial, ExtraSettings and peer-group caching — the remaining
synchronous DB reads — are deliberately left for a follow-up so the
invalidation design can be proven incrementally.
2026-04-24 11:41:59 +02:00
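The legacy-entry fallback described in this message can be sketched as follows. Field and method names here are assumptions for illustration, not the actual netbird implementation:

```go
package main

import "fmt"

// peerSyncEntry is a hypothetical shape of the cache entry described above.
type peerSyncEntry struct {
	Serial    uint64
	MetaHash  string
	AccountID string
	PeerID    string
	PeerKey   string
	Ephemeral bool
	HasUser   bool
}

// IsComplete reports whether the entry carries the extended step-2 fields.
// Legacy entries decode with zero values here, so they fail this check and
// fall through to the slow path, which rewrites them in the full shape.
func (e peerSyncEntry) IsComplete() bool {
	return e.AccountID != "" && e.PeerID != "" && e.PeerKey != ""
}

func main() {
	legacy := peerSyncEntry{Serial: 42, MetaHash: "abc"} // pre-step-2 shape
	full := peerSyncEntry{
		Serial: 42, MetaHash: "abc",
		AccountID: "acc", PeerID: "p1", PeerKey: "wg-pub",
	}
	fmt.Println(legacy.IsComplete(), full.IsComplete())
}
```

The point of the zero-value check is that no cache migration is needed: incompleteness alone routes old entries to the slow path exactly once.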
mlsmaycon
5993264d34 Add detailed timing logs to sync fast path operations 2026-04-24 08:07:12 +02:00
mlsmaycon
617ceab2e3 Add OnPeerConnectedWithPeer to optimize sync fast path operations 2026-04-22 22:40:31 +02:00
mlsmaycon
53deabbdb4 Add timing log for GetExtraSettings in sync fast path 2026-04-22 15:00:21 +02:00
mlsmaycon
ac3fe4343b Refactor sync fast path logging for improved clarity and timing accuracy 2026-04-22 14:24:52 +02:00
mlsmaycon
a4ae160993 Fix deferred logging function in commitFastPath for correct execution 2026-04-22 11:41:32 +02:00
mlsmaycon
3ac4263257 Add timing instrumentation for sync fast path functions 2026-04-22 01:23:44 +02:00
mlsmaycon
dc86c9655d Improve timing precision in sync fast path logging 2026-04-22 00:39:09 +02:00
mlsmaycon
66494d61af Replace Tracef with Debugf for sync fast path logging 2026-04-22 00:06:39 +02:00
mlsmaycon
46446acd30 Add detailed timing logs to sync fast path operations 2026-04-21 23:02:58 +02:00
mlsmaycon
3eb1298cb4 Refactor sync fast path tests and fix CI flakiness
- Introduce `skipOnWindows` helper to properly skip tests relying on Unix-specific paths.
- Replace fixed sleep with `require.Eventually` in `waitForPeerDisconnect` to address flakiness in CI.
- Split `commitFastPath` logic out of `runFastPathSync` to close race conditions and improve clarity.
- Update tests to leverage new helpers and more precise assertions (e.g., `waitForPeerDisconnect`).
- Add `flakyStore` test helper to exercise fail-closed behavior in flag handling.
- Enhance `RunFastPathFlagRoutine` to disable the flag on store read errors.
2026-04-21 17:07:31 +02:00
mlsmaycon
93391fc68f generate only current.bin and android_current.bin on ci/cd 2026-04-21 16:49:54 +02:00
mlsmaycon
48c080b861 Replace Redis dependency with a generic cache store for fast path flag handling 2026-04-21 16:28:24 +02:00
mlsmaycon
3716838c25 Remove unused cacheKey helper and testcontainers imports, simplify Redis container setup 2026-04-21 16:17:31 +02:00
mlsmaycon
5d58000dbd Merge branch 'main' into cached-serial-check-on-sync 2026-04-21 15:55:47 +02:00
mlsmaycon
8430b06f2a [management] Add Redis-backed kill switch for Sync fast path
Gate the peer-sync fast path on a runtime flag polled from Redis so operators can roll the optimisation out gradually and flip it off without a redeploy.

Without NB_PEER_SYNC_REDIS_ADDRESS the routine stays disabled, every Sync runs the full network map path, and no entries accumulate in the peer serial cache — bit-for-bit identical to the pre-fast-path behaviour. When the env var is set, a background goroutine polls the configured key (default "peerSyncFastPath") every minute; values "1" or "true" enable the fast path, anything else disables it.

- RunFastPathFlagRoutine mirrors shared/logleveloverrider: dedicated Redis connection, background ticker, redis.Nil treated as disabled.
- NewServer takes the flag handle; tryFastPathSync and the recordPeerSyncEntry helpers short-circuit when Enabled() is false.
- invalidatePeerSyncEntry still runs on Login regardless of flag state.
- NewFastPathFlag(bool) exposed for tests and callers that need to force a state without going through Redis.
2026-04-21 15:52:34 +02:00
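The flag semantics described above — only "1" or "true" enable, a missing key (redis.Nil) or any other value disables — can be sketched like this. Names are assumptions for illustration, not the actual netbird API:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// FastPathFlag is a sketch of the runtime kill switch described above.
type FastPathFlag struct{ enabled atomic.Bool }

// NewFastPathFlag forces an initial state, mirroring the constructor
// exposed for tests and callers that bypass Redis.
func NewFastPathFlag(v bool) *FastPathFlag {
	f := &FastPathFlag{}
	f.enabled.Store(v)
	return f
}

func (f *FastPathFlag) Enabled() bool { return f.enabled.Load() }

// applyPolledValue mirrors the poller's rule: only "1" or "true" enable
// the fast path; a missing key or any other value disables it (fail closed).
func (f *FastPathFlag) applyPolledValue(val string, missing bool) {
	f.enabled.Store(!missing && (val == "1" || val == "true"))
}

func main() {
	f := NewFastPathFlag(false)
	f.applyPolledValue("true", false)
	fmt.Println(f.Enabled())
	f.applyPolledValue("", true) // key absent: treated as disabled
	fmt.Println(f.Enabled())
}
```

The fail-closed default means a Redis outage or deleted key degrades to the pre-fast-path behaviour rather than leaving the optimisation stuck on.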
Zoltan Papp
5a89e6621b [client] Suppress ICE signaling (#5820)
* [client] Suppress ICE signaling and periodic offers in force-relay mode

When NB_FORCE_RELAY is enabled, skip WorkerICE creation entirely,
suppress ICE credentials in offer/answer messages, disable the
periodic ICE candidate monitor, and fix isConnectedOnAllWay to
only check relay status so the guard stops sending unnecessary offers.

* [client] Dynamically suppress ICE based on remote peer's offer credentials

Track whether the remote peer includes ICE credentials in its
offers/answers. When remote stops sending ICE credentials, skip
ICE listener dispatch, suppress ICE credentials in responses, and
exclude ICE from the guard connectivity check. When remote resumes
sending ICE credentials, re-enable all ICE behavior.

* [client] Fix nil SessionID panic and force ICE teardown on relay-only transition

Fix nil pointer dereference in signalOfferAnswer when SessionID is nil
(relay-only offers). Close stale ICE agent immediately when remote peer
stops sending ICE credentials to avoid traffic black-hole during the
ICE disconnect timeout.

* [client] Add relay-only fallback check when ICE is unavailable

Verify that a working relay connection to the peer is available when ICE is disabled, to prevent connectivity issues.

* [client] Add tri-state connection status to guard for smarter ICE retry (#5828)

* [client] Add tri-state connection status to guard for smarter ICE retry

Refactor isConnectedOnAllWay to return a ConnStatus enum (Connected,
Disconnected, PartiallyConnected) instead of a boolean. When relay is
up but ICE is not (PartiallyConnected), limit ICE offers to 3 retries
with exponential backoff then fall back to hourly attempts, reducing
unnecessary signaling traffic. Fully disconnected peers continue to
retry aggressively. External events (relay/ICE disconnect, signal/relay
reconnect) reset retry state to give ICE a fresh chance.

* [client] Clarify guard ICE retry state and trace log trigger

Split iceRetryState.attempt into shouldRetry (pure predicate) and
enterHourlyMode (explicit state transition) so the caller in
reconnectLoopWithRetry reads top-to-bottom. Restore the original
trace-log behavior in isConnectedOnAllWay so it only logs on full
disconnection, not on the new PartiallyConnected state.

* [client] Extract pure evalConnStatus and add unit tests

Split isConnectedOnAllWay into a thin method that snapshots state and
a pure evalConnStatus helper that takes a connStatusInputs struct, so
the tri-state decision logic can be exercised without constructing
full Worker or Handshaker objects. Add table-driven tests covering
force-relay, ICE-unavailable and fully-available code paths, plus
unit tests for iceRetryState budget/hourly transitions and reset.

* [client] Improve grammar in logs and refactor ICE credential checks
2026-04-21 15:52:08 +02:00
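The tri-state decision described above can be sketched as a pure function over a snapshot struct; the enum and field names here are assumptions, not the actual `evalConnStatus` signature:

```go
package main

import "fmt"

type ConnStatus int

const (
	Disconnected ConnStatus = iota
	PartiallyConnected
	Connected
)

// connStatusInputs is a snapshot of connection state, so the decision
// can be tested without a full Worker or Handshaker.
type connStatusInputs struct {
	relayUp      bool
	iceUp        bool
	iceAvailable bool // false in force-relay / ICE-suppressed mode
}

// evalConnStatus: relay-only peers count as Connected when ICE is
// unavailable; relay up but ICE down (while available) is the new
// PartiallyConnected state that throttles offer retries.
func evalConnStatus(in connStatusInputs) ConnStatus {
	switch {
	case !in.relayUp && !in.iceUp:
		return Disconnected
	case in.relayUp && (in.iceUp || !in.iceAvailable):
		return Connected
	default:
		return PartiallyConnected
	}
}

func main() {
	fmt.Println(evalConnStatus(connStatusInputs{relayUp: true, iceUp: true, iceAvailable: true}) == Connected)
	fmt.Println(evalConnStatus(connStatusInputs{relayUp: true, iceAvailable: false}) == Connected) // relay-only
	fmt.Println(evalConnStatus(connStatusInputs{relayUp: true, iceAvailable: true}) == PartiallyConnected)
	fmt.Println(evalConnStatus(connStatusInputs{}) == Disconnected)
}
```

Keeping the decision pure is what makes the table-driven tests mentioned above possible.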
Misha Bragin
06dfa9d4a5 [management] replace mailru/easyjson with netbirdio/easyjson fork (#5938) 2026-04-21 13:59:35 +02:00
Misha Bragin
45d9ee52c0 [self-hosted] add reverse proxy retention fields to combined YAML (#5930) 2026-04-21 10:21:11 +02:00
Zoltan Papp
3098f48b25 [client] fix ios network addresses mac filter (#5906)
* fix(client): skip MAC address filter for network addresses on iOS

iOS does not expose hardware (MAC) addresses due to Apple's privacy
restrictions (since iOS 14), causing networkAddresses() to return an
empty list because all interfaces are filtered out by the HardwareAddr
check. Move networkAddresses() to platform-specific files so iOS can
skip this filter.
2026-04-20 11:49:38 +02:00
Zoltan Papp
7f023ce801 [client] Android debug bundle support (#5888)
Add Android debug bundle support with Troubleshoot UI
2026-04-20 11:26:30 +02:00
Michael Uray
e361126515 [client] Fix WGIface.Close deadlock when DNS filter hook re-enters GetDevice (#5916)
WGIface.Close() took w.mu and held it across w.tun.Close(). The
underlying wireguard-go device waits for its send/receive goroutines to
drain before Close() returns, and some of those goroutines re-enter
WGIface during shutdown. In particular, the userspace packet filter DNS
hook in client/internal/dns.ServiceViaMemory.filterDNSTraffic calls
s.wgInterface.GetDevice() on every packet, which also needs w.mu. With
the Close-side holding the mutex, the read goroutine blocks in
GetDevice and Close waits forever for that goroutine to exit:

  goroutine N (TestDNSPermanent_updateUpstream):
    WGIface.Close -> holds w.mu -> tun.Close -> sync.WaitGroup.Wait
  goroutine M (wireguard read routine):
    FilteredDevice.Read -> filterOutbound -> udpHooksDrop ->
    filterDNSTraffic.func1 -> WGIface.GetDevice -> sync.Mutex.Lock

This surfaces as a 5-minute test timeout on the macOS Client/Unit
CI job (panic: test timed out after 5m0s, running tests:
TestDNSPermanent_updateUpstream).

Release w.mu before calling w.tun.Close(). The other Close steps
(wgProxyFactory.Free, waitUntilRemoved, Destroy) do not mutate any
fields guarded by w.mu beyond what Free() already does, so the lock
is not needed once the tun has started shutting down. A new unit test
in iface_close_test.go uses a fake WGTunDevice to reproduce the
deadlock deterministically without requiring CAP_NET_ADMIN.
2026-04-20 10:36:19 +02:00
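The shape of the fix — drop the interface mutex before closing the tun, so shutdown goroutines that re-enter GetDevice can still acquire it — can be reproduced with a minimal model. Types here are illustrative stand-ins, not the actual WGIface code:

```go
package main

import (
	"fmt"
	"sync"
)

// fakeTun models wireguard-go: Close waits for a goroutine that
// re-enters the owning interface, like the DNS filter hook does.
type fakeTun struct {
	getDevice func() // simulates the read goroutine calling GetDevice
	wg        sync.WaitGroup
}

func (t *fakeTun) Close() {
	t.wg.Add(1)
	go func() {
		defer t.wg.Done()
		t.getDevice() // would block forever if the caller still held the lock
	}()
	t.wg.Wait() // device waits for its goroutines to drain
}

type iface struct {
	mu  sync.Mutex
	tun *fakeTun
}

func (w *iface) GetDevice() {
	w.mu.Lock()
	defer w.mu.Unlock()
}

// Close releases w.mu before tun.Close, per the fix; holding it across
// tun.Close reproduces the deadlock described above.
func (w *iface) Close() {
	w.mu.Lock()
	tun := w.tun
	w.mu.Unlock()
	tun.Close()
}

func main() {
	w := &iface{}
	w.tun = &fakeTun{getDevice: w.GetDevice}
	w.Close()
	fmt.Println("closed without deadlock")
}
```

Moving `tun.Close()` one line inside the locked region hangs this program, which is essentially what the fake-device unit test verifies deterministically.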
Viktor Liu
95213f7157 [client] Use Match host+exec instead of Host+Match in SSH client config (#5903) 2026-04-20 10:24:11 +02:00
Viktor Liu
2e0e3a3601 [client] Replace exclusion routes with scoped default + IP_BOUND_IF on macOS (#5918) 2026-04-20 10:01:01 +02:00
mlsmaycon
3f4ef0031b [management] Skip full network map on Sync when peer state is unchanged
Introduce a peer-sync cache keyed by WireGuard pubkey that records the
NetworkMap.Serial and meta hash the server last delivered to each peer.
When a Sync request arrives from a non-Android peer whose cached serial
matches the current account serial and whose meta hash matches the last
delivery, short-circuit SyncAndMarkPeer and reply with a NetbirdConfig-only
SyncResponse mirroring the shape TimeBasedAuthSecretsManager already pushes
for TURN/Relay token rotation. The client keeps its existing network map
state and refreshes only control-plane credentials.

The fast path avoids GetAccountWithBackpressure, the full per-peer map
assembly, posture-check recomputation and the large encrypted payload on
every reconnect of a peer whose account is quiescent. Slow path remains
the source of truth for any real state change; every full-map send (initial
sync or streamed NetworkMap update) rewrites the cache, and every Login
deletes it so a fresh map is guaranteed after SSH key rotation, approval
changes or re-registration.

Backend-only: no proto changes and no client changes. Compatibility is
provided by the existing client handling of nil NetworkMap in handleSync
(every version from v0.20.0 on). Android is gated out at the server because
its readInitialSettings path calls GrpcClient.GetNetworkMap which errors on
nil map. The cache is wired through BaseServer.CacheStore() so it shares
the same Redis/in-memory backend as OneTimeTokenStore and PKCEVerifierStore.

Test coverage lands in five layers:
- Pure decision function (peer_serial_cache_decision_test.go)
- Cache wrapper with TTL + concurrency (peer_serial_cache_test.go)
- Response shape unit tests (sync_fast_path_response_test.go)
- In-process gRPC behavioural tests covering first sync, reconnect skip,
  android never-skip, meta change, login invalidation, and serial advance
  (management/server/sync_fast_path_test.go)
- Frozen SyncRequest wire-format fixtures for v0.20.0 / v0.40.0 / v0.60.0
  / current / android replayed against the in-process server
  (management/server/sync_legacy_wire_test.go + testdata fixtures)
2026-04-17 16:20:04 +02:00
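The core skip decision described above can be sketched as a small pure function; names are assumptions for illustration, not the actual `peer_serial_cache_decision` code:

```go
package main

import "fmt"

// syncCacheEntry records what the server last delivered to a peer.
type syncCacheEntry struct {
	Serial   uint64 // NetworkMap.Serial of the last full map sent
	MetaHash string // hash of the peer meta at last delivery
}

// shouldSkipFullMap returns true only when the cached serial matches the
// current account serial, the peer's meta hash is unchanged, and the
// client is not Android (whose readInitialSettings errors on a nil map).
func shouldSkipFullMap(entry *syncCacheEntry, accountSerial uint64, metaHash string, isAndroid bool) bool {
	if isAndroid || entry == nil {
		return false
	}
	return entry.Serial == accountSerial && entry.MetaHash == metaHash
}

func main() {
	e := &syncCacheEntry{Serial: 7, MetaHash: "h1"}
	fmt.Println(shouldSkipFullMap(e, 7, "h1", false)) // quiescent reconnect: skip
	fmt.Println(shouldSkipFullMap(e, 8, "h1", false)) // serial advanced: full map
	fmt.Println(shouldSkipFullMap(e, 7, "h1", true))  // android: never skip
}
```

Every condition defaults to the full-map slow path, so a cold cache, a deleted entry (Login invalidation), or any state change degrades to the existing behaviour.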
Nicolas Frati
8ae8f2098f [management] chores: fix lint error on google workspace (#5907)
* chores: fix lint error on google workspace

* chores: updated google api dependency

* update google golang api sdk to latest
2026-04-16 20:02:09 +02:00
Viktor Liu
a39787d679 [infrastructure] Add CrowdSec LAPI container to self-hosted setup script (#5880) 2026-04-16 18:06:38 +02:00
Maycon Santos
53b04e512a [management] Reuse a single cache store across all management server consumers (#5889)
* Add support for legacy IDP cache environment variable

* Centralize cache store creation to reuse a single Redis connection pool

Each cache consumer (IDP cache, token store, PKCE store, secrets manager,
EDR validator) was independently calling NewStore, creating separate Redis
clients with their own connection pools — up to 1400 potential connections
from a single management server process.

Introduce a shared CacheStore() singleton on BaseServer that creates one
store at boot and injects it into all consumers. Consumer constructors now
receive a store.StoreInterface instead of creating their own.

For Redis mode, all consumers share one connection pool (1000 max conns).
For in-memory mode, all consumers share one GoCache instance.

* Update management-integrations module to latest version

* sync go.sum

* Export `GetAddrFromEnv` to allow reuse across packages

* Update management-integrations module version in go.mod and go.sum

* Update management-integrations module version in go.mod and go.sum
2026-04-16 16:04:53 +02:00
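The create-once pattern behind the shared store can be sketched with `sync.Once`; this only illustrates the singleton wiring, not the real `BaseServer.CacheStore()` implementation:

```go
package main

import (
	"fmt"
	"sync"
)

// cacheStore stands in for the Redis/GoCache-backed store.
type cacheStore struct{}

type baseServer struct {
	once  sync.Once
	store *cacheStore
}

var created int

// CacheStore creates the store on first call and hands the same instance
// to every consumer (IDP cache, token store, PKCE store, ...), so only
// one Redis connection pool exists per process.
func (s *baseServer) CacheStore() *cacheStore {
	s.once.Do(func() {
		created++
		s.store = &cacheStore{}
	})
	return s.store
}

func main() {
	s := &baseServer{}
	a := s.CacheStore() // e.g. injected into the token store
	b := s.CacheStore() // e.g. injected into the PKCE store
	fmt.Println(a == b)
}
```

Injecting a `store.StoreInterface` into consumer constructors (instead of each calling `NewStore`) is what collapses the per-consumer pools into one.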
Viktor Liu
633dde8d1f [client] Reconnect conntrack netlink listener on error (#5885) 2026-04-16 22:30:36 +09:00
Michael Uray
7e4542adde fix(client): populate NetworkAddresses on iOS for posture checks (#5900)
The iOS GetInfo() function never populated NetworkAddresses, causing
the peer_network_range_check posture check to fail for all iOS clients.

This adds the same networkAddresses() call that macOS, Linux, Windows,
and FreeBSD already use.

Fixes: #3968
Fixes: #4657
2026-04-16 14:25:55 +02:00
Viktor Liu
d4c61ed38b [client] Add mangle FORWARD guard to prevent Docker DNAT bypass of ACL rules (#5697) 2026-04-16 14:02:52 +02:00
Viktor Liu
6b540d145c [client] Add --disable-networks flag to block network selection (#5896) 2026-04-16 14:02:31 +02:00
Bethuel Mmbaga
08f624507d [management] Enforce peer or peer groups requirement for network routers (#5894) 2026-04-16 13:12:19 +03:00
Viktor Liu
95bc01e48f [client] Allow clearing saved service env vars with --service-env "" (#5893) 2026-04-15 19:22:08 +02:00
Viktor Liu
0d86de47df [client] Add PCP support (#5219) 2026-04-15 11:43:16 +02:00
Viktor Liu
e804a705b7 [infrastructure] Update sign pipeline version to v0.1.2 (#5884) 2026-04-14 17:08:35 +02:00
Pascal Fischer
46fc8c9f65 [proxy] direct redirect to SSO (#5874) 2026-04-14 13:47:02 +02:00
Viktor Liu
d7ad908962 [misc] Add CI check for proto version string changes (#5854)
* Add CI check for proto version string changes

* Handle pagination and missing patch data in proto version check
2026-04-14 13:36:26 +02:00
Pascal Fischer
c5623307cc [management] add context cancel monitoring (#5879) 2026-04-14 12:49:18 +02:00
134 changed files with 6538 additions and 650 deletions

View File

@@ -426,8 +426,11 @@ jobs:
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Generate current sync wire fixtures
run: go run ./management/server/testdata/sync_request_wire/generate.go
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \

View File

@@ -0,0 +1,62 @@
name: Proto Version Check
on:
pull_request:
paths:
- "**/*.pb.go"
jobs:
check-proto-versions:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@v7
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number,
per_page: 100,
});
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(
`Proto version strings changed in generated files.\n` +
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
`Regenerate with the matching tool versions.\n\n` +
details
);
return;
}
console.log('No proto version string changes detected');

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.1"
SIGN_PIPE_VER: "v0.1.2"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)

View File

@@ -8,6 +8,7 @@ import (
"os"
"slices"
"sync"
"time"
"golang.org/x/exp/maps"
@@ -15,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -26,6 +28,7 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -68,7 +71,30 @@ type Client struct {
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
}
// NewClient instantiate a new Client
@@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// RunWithoutLogin we apply this type of run function when the backend has been started without UI (i.e. after reboot).
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// Stop the internal client and free the resources
@@ -173,11 +203,12 @@ func (c *Client) Stop() {
}
func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil {
cc := c.getConnectClient()
if cc == nil {
return fmt.Errorf("engine not running")
}
e := c.connectClient.Engine()
e := cc.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
@@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd)
}
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray {
}
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
cc := c.getConnectClient()
if cc == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
engine := cc.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
@@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient
client := c.getConnectClient()
if client == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -7,4 +7,5 @@ package android
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
CacheDir() string
}

View File

@@ -75,6 +75,7 @@ var (
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
networksDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",

View File

@@ -44,10 +44,13 @@ func init() {
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
`Use --service-env "" to clear all saved env vars. ` +
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
}
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}

View File

@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
args = append(args, "--disable-update-settings")
}
if networksDisabled {
args = append(args, "--disable-networks")
}
return args
}

View File

@@ -28,6 +28,7 @@ type serviceParams struct {
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -78,11 +79,12 @@ func currentServiceParams() *serviceParams {
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
DisableNetworks: networksDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil && len(parsed) > 0 {
if err == nil {
params.ServiceEnvVars = parsed
}
}
@@ -142,31 +144,46 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
updateSettingsDisabled = params.DisableUpdateSettings
}
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
networksDisabled = params.DisableNetworks
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was explicitly set with values, explicit values win on key
// conflict but saved keys not in the explicit set are carried over.
// If --service-env was explicitly set to empty, all saved env vars are cleared.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if len(params.ServiceEnvVars) == 0 {
return
}
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
if len(params.ServiceEnvVars) > 0 {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
}
return
}
// Explicit env vars were provided: merge saved values underneath.
// Flag was explicitly set: parse what the user provided.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
// If the user passed an empty value (e.g. --service-env ""), clear all
// saved env vars rather than merging.
if len(explicit) == 0 {
serviceEnvVars = nil
return
}
if len(params.ServiceEnvVars) == 0 {
return
}
// Merge saved values underneath explicit ones.
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict

View File

@@ -327,6 +327,41 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", ""))
saved := &serviceParams{
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
}
applyServiceEnvParams(cmd, saved)
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
}
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
params := currentServiceParams()
// After parsing, the empty string is skipped, resulting in an empty map.
// The map should still be set (not nil) so it overwrites saved values.
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
@@ -500,6 +535,7 @@ func fieldToGlobalVar(field string) string {
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"DisableNetworks": "networksDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {


@@ -13,6 +13,8 @@ import (
"github.com/netbirdio/management-integrations/integrations"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
ctx := context.Background()
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatal(err)
}
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
t.Fatal(err)
}
@@ -127,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -152,7 +160,7 @@ func startClientDaemon(
s := grpc.NewServer()
server := client.New(ctx,
"", "", false, false)
"", "", false, false, false)
if err := server.Start(); err != nil {
t.Fatal(err)
}


@@ -21,6 +21,10 @@ const (
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
@@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error {
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
@@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error {
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
@@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error {
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
@@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {


@@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error
@@ -225,7 +224,15 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}
if err := w.tun.Close(); err != nil {
// Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}


@@ -0,0 +1,113 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
WGIface.GetDevice(), which needs the same mutex. The fix is to release
the lock before calling tun.Close().
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}


@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
}
if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}


@@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
cacheDir string,
) error {
// on non-Android OSes these variables will be nil
mobileDependency := MobileDependency{
@@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
}
@@ -338,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err)
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)


@@ -16,7 +16,6 @@ import (
"path/filepath"
"runtime"
"runtime/pprof"
"slices"
"sort"
"strings"
"time"
@@ -31,7 +30,6 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const readmeContent = `Netbird debug bundle
@@ -234,6 +232,7 @@ type BundleGenerator struct {
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
@@ -256,6 +255,7 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
@@ -275,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
@@ -287,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
if err != nil {
return "", fmt.Errorf("create zip file: %w", err)
}
@@ -373,15 +374,8 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
if err := g.addPlatformLog(); err != nil {
log.Errorf("failed to add logs to debug bundle: %v", err)
}
if err := g.addUpdateLogs(); err != nil {


@@ -0,0 +1,41 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}


@@ -0,0 +1,25 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}


@@ -140,6 +140,7 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config
LogPath string
TempDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -569,7 +570,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -1095,6 +1096,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)


@@ -55,6 +55,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -1634,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -1656,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
@@ -1665,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil, nil, nil)
if err != nil {
return nil, "", err
}


@@ -22,4 +22,8 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager
FileDescriptor int32
StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
}


@@ -7,7 +7,9 @@ import (
"fmt"
"net/netip"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
nfct "github.com/ti-mo/conntrack"
@@ -17,31 +19,64 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
const defaultChannelSize = 100
const (
defaultChannelSize = 100
reconnectInitInterval = 5 * time.Second
reconnectMaxInterval = 5 * time.Minute
reconnectRandomization = 0.5
)
// listener abstracts a netlink conntrack connection for testability.
type listener interface {
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
Close() error
}
// ConnTrack manages kernel-based conntrack events
type ConnTrack struct {
flowLogger nftypes.FlowLogger
iface nftypes.IFaceMapper
conn *nfct.Conn
conn listener
mux sync.Mutex
dial func() (listener, error)
instanceID uuid.UUID
started bool
done chan struct{}
sysctlModified bool
}
// DialFunc is a constructor for netlink conntrack connections.
type DialFunc func() (listener, error)
// Option configures a ConnTrack instance.
type Option func(*ConnTrack)
// WithDialer overrides the default netlink dialer, primarily for testing.
func WithDialer(dial DialFunc) Option {
return func(c *ConnTrack) {
c.dial = dial
}
}
func defaultDial() (listener, error) {
return nfct.Dial(nil)
}
// New creates a new connection tracker that interfaces with the kernel's conntrack system
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
return &ConnTrack{
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
ct := &ConnTrack{
flowLogger: flowLogger,
iface: iface,
instanceID: uuid.New(),
started: false,
dial: defaultDial,
done: make(chan struct{}, 1),
}
for _, opt := range opts {
opt(ct)
}
return ct
}
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
@@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error {
c.EnableAccounting()
}
conn, err := nfct.Dial(nil)
conn, err := c.dial()
if err != nil {
c.RestoreAccounting()
return fmt.Errorf("dial conntrack: %w", err)
}
c.conn = conn
@@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error {
log.Errorf("Error closing conntrack connection: %v", err)
}
c.conn = nil
c.RestoreAccounting()
return fmt.Errorf("start conntrack listener: %w", err)
}
// Drain any stale stop signal from a previous cycle.
select {
case <-c.done:
default:
}
c.started = true
go c.receiverRoutine(events, errChan)
@@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error)
case event := <-events:
c.handleEvent(event)
case err := <-errChan:
log.Errorf("Error from conntrack event listener: %v", err)
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
if events, errChan = c.handleListenerError(err); events == nil {
return
}
return
case <-c.done:
return
}
}
}
// handleListenerError closes the failed connection and attempts to reconnect.
// Returns new channels on success, or nil if shutdown was requested.
func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) {
log.Warnf("conntrack event listener failed: %v", err)
c.closeConn()
return c.reconnect()
}
func (c *ConnTrack) closeConn() {
c.mux.Lock()
defer c.mux.Unlock()
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Debugf("close conntrack connection: %v", err)
}
c.conn = nil
}
}
// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff.
// Returns new channels on success, or nil if shutdown was requested.
func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
bo := &backoff.ExponentialBackOff{
InitialInterval: reconnectInitInterval,
RandomizationFactor: reconnectRandomization,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: reconnectMaxInterval,
MaxElapsedTime: 0, // retry indefinitely
Clock: backoff.SystemClock,
}
bo.Reset()
for {
delay := bo.NextBackOff()
log.Infof("reconnecting conntrack listener in %s", delay)
select {
case <-c.done:
c.mux.Lock()
c.started = false
c.mux.Unlock()
return nil, nil
case <-time.After(delay):
}
conn, err := c.dial()
if err != nil {
log.Warnf("reconnect conntrack dial: %v", err)
continue
}
events := make(chan nfct.Event, defaultChannelSize)
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
netfilter.GroupCTNew,
netfilter.GroupCTDestroy,
})
if err != nil {
log.Warnf("reconnect conntrack listen: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("close conntrack connection: %v", closeErr)
}
continue
}
c.mux.Lock()
if !c.started {
// Stop() ran while we were reconnecting.
c.mux.Unlock()
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("close conntrack connection: %v", closeErr)
}
return nil, nil
}
c.conn = conn
c.mux.Unlock()
log.Infof("conntrack listener reconnected successfully")
return events, errChan
}
}
// Stop stops the connection tracking. This method is idempotent.
func (c *ConnTrack) Stop() {
c.mux.Lock()
@@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error {
c.mux.Lock()
defer c.mux.Unlock()
if c.started {
select {
case c.done <- struct{}{}:
default:
}
if !c.started {
return nil
}
select {
case c.done <- struct{}{}:
default:
}
c.started = false
var closeErr error
if c.conn != nil {
err := c.conn.Close()
closeErr = c.conn.Close()
c.conn = nil
c.started = false
}
c.RestoreAccounting()
c.RestoreAccounting()
if err != nil {
return fmt.Errorf("close conntrack: %w", err)
}
if closeErr != nil {
return fmt.Errorf("close conntrack: %w", closeErr)
}
return nil


@@ -0,0 +1,224 @@
//go:build linux && !android
package conntrack
import (
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nfct "github.com/ti-mo/conntrack"
"github.com/ti-mo/netfilter"
)
type mockListener struct {
errChan chan error
closed atomic.Bool
closedCh chan struct{}
}
func newMockListener() *mockListener {
return &mockListener{
errChan: make(chan error, 1),
closedCh: make(chan struct{}),
}
}
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
return m.errChan, nil
}
func (m *mockListener) Close() error {
if m.closed.CompareAndSwap(false, true) {
close(m.closedCh)
}
return nil
}
func TestReconnectAfterError(t *testing.T) {
first := newMockListener()
second := newMockListener()
third := newMockListener()
listeners := []*mockListener{first, second, third}
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := int(callCount.Add(1)) - 1
return listeners[n], nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Inject an error on the first listener.
first.errChan <- assert.AnError
// Wait for reconnect to complete.
require.Eventually(t, func() bool {
return callCount.Load() >= 2
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
// The first connection must have been closed.
select {
case <-first.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("first connection was not closed")
}
// Verify the receiver is still running by injecting and handling a second error.
second.errChan <- assert.AnError
require.Eventually(t, func() bool {
return callCount.Load() >= 3
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
ct.Stop()
}
func TestStopDuringReconnectBackoff(t *testing.T) {
mock := newMockListener()
ct := New(nil, nil, WithDialer(func() (listener, error) {
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Trigger an error so the receiver enters reconnect.
mock.errChan <- assert.AnError
// Wait for the error handler to close the old listener before calling Stop.
select {
case <-mock.closedCh:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for reconnect to start")
}
// Stop while reconnecting.
ct.Stop()
ct.mux.Lock()
assert.False(t, ct.started, "started should be false after Stop")
assert.Nil(t, ct.conn, "conn should be nil after Stop")
ct.mux.Unlock()
}
func TestStopRaceWithReconnectDial(t *testing.T) {
first := newMockListener()
dialStarted := make(chan struct{})
dialProceed := make(chan struct{})
second := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
// Second dial: signal that we're in progress, wait for test to call Stop.
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Trigger error to enter reconnect.
first.errChan <- assert.AnError
// Wait for reconnect's second dial to begin.
select {
case <-dialStarted:
case <-time.After(15 * time.Second):
t.Fatal("timed out waiting for reconnect dial")
}
// Stop while dial is in progress (conn is nil at this point).
ct.Stop()
// Let the dial complete. reconnect should detect started==false and close the new conn.
close(dialProceed)
// The second connection should be closed (not leaked).
select {
case <-second.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("second connection was leaked after Stop")
}
ct.mux.Lock()
assert.False(t, ct.started)
assert.Nil(t, ct.conn)
ct.mux.Unlock()
}
func TestCloseRaceWithReconnectDial(t *testing.T) {
first := newMockListener()
dialStarted := make(chan struct{})
dialProceed := make(chan struct{})
second := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
first.errChan <- assert.AnError
select {
case <-dialStarted:
case <-time.After(15 * time.Second):
t.Fatal("timed out waiting for reconnect dial")
}
// Close while dial is in progress (conn is nil).
require.NoError(t, ct.Close())
close(dialProceed)
// The second connection should be closed (not leaked).
select {
case <-second.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("second connection was leaked after Close")
}
ct.mux.Lock()
assert.False(t, ct.started)
assert.Nil(t, ct.conn)
ct.mux.Unlock()
}
func TestStartIsIdempotent(t *testing.T) {
mock := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
callCount.Add(1)
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Second Start should be a no-op.
err = ct.Start(false)
require.NoError(t, err)
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
ct.Stop()
}


@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
forceRelay := IsForceRelayed()
if !forceRelay {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() {
if !forceRelay {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
}
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
conn.workerICE.Close()
if conn.workerICE != nil {
conn.workerICE.Close()
}
if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn()
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate handles an ICE connection candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate()
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
if conn.workerICE != nil {
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer is established
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusConnecting
}
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
// would be better to protect this with a mutex, but it could cause deadlock with Close function
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
//
// The result is a tri-state:
// - ConnStatusConnected: all available transports are up
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
// - ConnStatusDisconnected: no working transport
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
defer func() {
if !connected {
if status == guard.ConnStatusDisconnected {
conn.logTraceConnState()
}
}()
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
iceWorkerCreated := conn.workerICE != nil
var iceInProgress bool
if iceWorkerCreated {
iceInProgress = conn.workerICE.InProgress()
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false
}
}
return true
return evalConnStatus(connStatusInputs{
forceRelay: IsForceRelayed(),
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
iceWorkerCreated: iceWorkerCreated,
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
iceInProgress: iceInProgress,
})
}
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
// "Relay up and needed" — the peer uses relay and the transport is connected.
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
if in.forceRelay {
return boolToConnStatus(relayUsedAndUp)
}
// Remote peer doesn't support ICE, or we haven't created the worker yet:
// relay is the only possible transport.
if !in.remoteSupportsICE || !in.iceWorkerCreated {
return boolToConnStatus(relayUsedAndUp)
}
// ICE counts as "up" when the status is anything other than Disconnected, OR
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
iceUp := in.iceStatusConnecting || in.iceInProgress
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
relayOK := !in.peerUsesRelay || in.relayConnected
switch {
case iceUp && relayOK:
return guard.ConnStatusConnected
case relayUsedAndUp:
// Relay is up but ICE is down — partially connected.
return guard.ConnStatusPartiallyConnected
default:
return guard.ConnStatusDisconnected
}
}
func boolToConnStatus(connected bool) guard.ConnStatus {
if connected {
return guard.ConnStatusConnected
}
return guard.ConnStatusDisconnected
}

View File

@@ -13,6 +13,20 @@ const (
StatusConnected
)
// connStatusInputs is the primitive-valued snapshot of the state that drives the
// tri-state connection classification. Extracted so the decision logic can be unit-tested
// without constructing full Worker/Handshaker objects.
type connStatusInputs struct {
forceRelay bool // NB_FORCE_RELAY or JS/WASM
peerUsesRelay bool // remote peer advertises relay support AND local has relay
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
remoteSupportsICE bool // remote peer sent ICE credentials
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
iceStatusConnecting bool // statusICE is anything other than Disconnected
iceInProgress bool // a negotiation is currently in flight
}
// ConnStatus describe the status of a peer's connection
type ConnStatus int32

View File

@@ -0,0 +1,201 @@
package peer
import (
"testing"
"github.com/netbirdio/netbird/client/internal/peer/guard"
)
func TestEvalConnStatus_ForceRelay(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "force relay, peer uses relay, relay up",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: true,
},
want: guard.ConnStatusConnected,
},
{
name: "force relay, peer uses relay, relay down",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: false,
},
want: guard.ConnStatusDisconnected,
},
{
name: "force relay, peer does NOT use relay - disconnected forever",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: false,
relayConnected: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "remote does not support ICE, peer uses relay, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer uses relay, relay down",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE worker not yet created, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: true,
iceWorkerCreated: false,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer does not use relay",
in: connStatusInputs{
peerUsesRelay: false,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
base := connStatusInputs{
remoteSupportsICE: true,
iceWorkerCreated: true,
}
tests := []struct {
name string
mutator func(*connStatusInputs)
want guard.ConnStatus
}{
{
name: "ICE connected, relay connected, peer uses relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE connected, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE InProgress only, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.iceStatusConnecting = false
in.iceInProgress = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE down, relay up, peer uses relay -> partial",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusPartiallyConnected,
},
{
name: "ICE down, peer does NOT use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = false
in.iceStatusConnecting = true
},
// relayOK = false (peer uses relay but it's down), iceUp = true
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
// falls into default: Disconnected.
want: guard.ConnStatusDisconnected,
},
{
name: "ICE down, relay up but peer does not use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = true // not actually used since peer doesn't rely on it
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
in := base
tc.mutator(&in)
if got := evalConnStatus(in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
}
})
}
}

View File

@@ -10,7 +10,7 @@ const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func isForceRelayed() bool {
func IsForceRelayed() bool {
if runtime.GOOS == "js" {
return true
}

View File

@@ -8,7 +8,19 @@ import (
log "github.com/sirupsen/logrus"
)
type isConnectedFunc func() bool
// ConnStatus represents the connection state as seen by the guard.
type ConnStatus int
const (
// ConnStatusDisconnected means neither ICE nor Relay is connected.
ConnStatusDisconnected ConnStatus = iota
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
ConnStatusPartiallyConnected
// ConnStatusConnected means all required connections are established.
ConnStatusConnected
)
type connStatusFunc func() ConnStatus
// Guard is responsible for the reconnection logic.
// It triggers sending an offer to the peer when the connection has issues.
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
// - ICE candidate changes
type Guard struct {
log *log.Entry
isConnectedOnAllWay isConnectedFunc
isConnectedOnAllWay connStatusFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{}
}
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
log: log,
isConnectedOnAllWay: isConnectedFn,
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
}
}
// reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
//
// Behavior depends on the connection state reported by isConnectedOnAllWay:
// - Connected: no action, the peer is fully reachable.
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
//
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
tickerChannel := ticker.C
iceState := &iceRetryState{log: g.log}
defer iceState.reset()
for {
select {
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after the backoff timeout, ticker.C is closed. Swap in a dummy channel to avoid a busy loop
tickerChannel = make(<-chan time.Time)
continue
case <-tickerChannel:
switch g.isConnectedOnAllWay() {
case ConnStatusConnected:
// all good, nothing to do
case ConnStatusDisconnected:
callback()
case ConnStatusPartiallyConnected:
if iceState.shouldRetry() {
callback()
} else {
iceState.enterHourlyMode()
ticker.Stop()
tickerChannel = iceState.hourlyC()
}
}
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
return backoff.NewTicker(bo)
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,

View File

@@ -0,0 +1,61 @@
package guard
import (
"time"
log "github.com/sirupsen/logrus"
)
const (
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
maxICERetries = 3
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
iceRetryInterval = 1 * time.Hour
)
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
// After maxICERetries attempts it switches to a periodic hourly retry.
type iceRetryState struct {
log *log.Entry
retries int
hourly *time.Ticker
}
func (s *iceRetryState) reset() {
s.retries = 0
if s.hourly != nil {
s.hourly.Stop()
s.hourly = nil
}
}
// shouldRetry reports whether the caller should send another ICE offer on this tick.
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
// to the hourly ticker via enterHourlyMode + hourlyC.
func (s *iceRetryState) shouldRetry() bool {
if s.hourly != nil {
s.log.Debugf("hourly ICE retry attempt")
return true
}
s.retries++
if s.retries <= maxICERetries {
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
return true
}
return false
}
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
func (s *iceRetryState) enterHourlyMode() {
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
s.hourly = time.NewTicker(iceRetryInterval)
}
func (s *iceRetryState) hourlyC() <-chan time.Time {
if s.hourly == nil {
return nil
}
return s.hourly.C
}

View File

@@ -0,0 +1,103 @@
package guard
import (
"testing"
log "github.com/sirupsen/logrus"
)
func newTestRetryState() *iceRetryState {
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
}
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
s := newTestRetryState()
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
}
}
}
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries; i++ {
_ = s.shouldRetry()
}
if s.shouldRetry() {
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
}
}
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
s := newTestRetryState()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
}
}
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
defer s.reset()
if s.hourlyC() == nil {
t.Fatalf("hourlyC returned nil after enterHourlyMode")
}
}
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
s := newTestRetryState()
s.enterHourlyMode()
defer s.reset()
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false in hourly mode, want true")
}
// Subsequent calls also return true — we keep retrying on each hourly tick.
if !s.shouldRetry() {
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
}
}
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
s.reset()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel after reset")
}
if s.retries != 0 {
t.Fatalf("retries = %d after reset, want 0", s.retries)
}
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
}
}
}
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
s := newTestRetryState()
s.reset()
s.reset() // second call must not panic or re-stop a nil ticker
if s.hourlyC() != nil {
t.Fatalf("hourlyC non-nil after double reset")
}
}

View File

@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
return srw
}
func (w *SRWatcher) Start() {
func (w *SRWatcher) Start(disableICEMonitor bool) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
if !disableICEMonitor {
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
}
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
@@ -43,6 +44,10 @@ type OfferAnswer struct {
SessionID *ICESessionID
}
func (o *OfferAnswer) hasICECredentials() bool {
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
}
type Handshaker struct {
mu sync.Mutex
log *log.Entry
@@ -59,6 +64,10 @@ type Handshaker struct {
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
remoteICESupported atomic.Bool
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
@@ -66,7 +75,7 @@ type Handshaker struct {
}
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
return &Handshaker{
h := &Handshaker{
log: log,
config: config,
signaler: signaler,
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
// assume remote supports ICE until we learn otherwise from received offers
h.remoteICESupported.Store(ice != nil)
return h
}
func (h *Handshaker) RemoteICESupported() bool {
return h.remoteICESupported.Load()
}
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
for {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer)
}
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue
}
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer)
}
case <-ctx.Done():
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
}
if h.ice != nil && h.RemoteICESupported() {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer.IceCredentials = IceCredentials{uFrag, pwd}
answer.SessionID = &sid
}
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
return answer
}
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
hasICE := offer.hasICECredentials()
prev := h.remoteICESupported.Swap(hasICE)
if prev != hasICE {
if hasICE {
h.log.Infof("remote peer started sending ICE credentials")
} else {
h.log.Infof("remote peer stopped sending ICE credentials")
if h.ice != nil {
h.ice.Close()
}
}
}
}

View File

@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
var sessionIDBytes []byte
if offerAnswer.SessionID != nil {
var err error
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
}
msg, err := signal.MarshalCredential(
s.wgPrivateKey,

View File

@@ -8,18 +8,27 @@ import (
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
)
func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
return parseBoolEnv(envDisableNATMapper)
}
func isHealthCheckDisabled() bool {
return parseBoolEnv(envDisablePCPHealthCheck)
}
func parseBoolEnv(key string) bool {
val := os.Getenv(key)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
log.Warnf("failed to parse %s: %v", key, err)
return false
}
return disabled

View File

@@ -12,12 +12,15 @@ import (
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
const (
defaultMappingTTL = 2 * time.Hour
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
defaultMappingTTL = 2 * time.Hour
healthCheckInterval = 1 * time.Minute
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
@@ -154,7 +157,7 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := nat.DiscoverGateway(discoverCtx)
gateway, err := discoverGateway(discoverCtx)
if err != nil {
return nil, nil, fmt.Errorf("discover gateway: %w", err)
}
@@ -189,7 +192,6 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}
mapping := &Mapping{
@@ -208,27 +210,87 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
if ttl == 0 {
// Permanent mappings don't expire, just wait for cancellation.
<-ctx.Done()
// Permanent mappings don't expire, just wait for cancellation
// but still run health checks for PCP gateways.
m.permanentLeaseLoop(ctx, gateway)
return
}
ticker := time.NewTicker(ttl / 2)
defer ticker.Stop()
renewTicker := time.NewTicker(ttl / 2)
healthTicker := time.NewTicker(healthCheckInterval)
defer renewTicker.Stop()
defer healthTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
case <-renewTicker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
case <-healthTicker.C:
if m.checkHealthAndRecreate(ctx, gateway) {
renewTicker.Reset(ttl / 2)
}
}
}
}
func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) {
healthTicker := time.NewTicker(healthCheckInterval)
defer healthTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-healthTicker.C:
m.checkHealthAndRecreate(ctx, gateway)
}
}
}
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
if isHealthCheckDisabled() {
return false
}
m.mappingLock.Lock()
hasMapping := m.mapping != nil
m.mappingLock.Unlock()
if !hasMapping {
return false
}
pcpNAT, ok := gateway.(*pcp.NAT)
if !ok {
return false
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
if err != nil {
log.Debugf("PCP health check failed: %v", err)
return false
}
if serverRestarted {
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
if err := m.renewMapping(ctx, gateway); err != nil {
log.Errorf("failed to recreate port mapping after server restart: %v", err)
return false
}
return true
}
return false
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

View File

@@ -0,0 +1,408 @@
package pcp
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultTimeout = 3 * time.Second
responseBufferSize = 128
// RFC 6887 Section 8.1.1 retry timing
initialRetryDelay = 3 * time.Second
maxRetryDelay = 1024 * time.Second
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
)
// Client is a PCP protocol client.
// All methods are safe for concurrent use.
type Client struct {
gateway netip.Addr
timeout time.Duration
mu sync.Mutex
// localIP caches the resolved local IP address.
localIP netip.Addr
// lastEpoch is the last observed server epoch value.
lastEpoch uint32
// epochTime tracks when lastEpoch was received for state loss detection.
epochTime time.Time
// externalIP caches the external IP from the last successful MAP response.
externalIP netip.Addr
// epochStateLost is set when epoch indicates server restart.
epochStateLost bool
}
// NewClient creates a new PCP client for the gateway at the given IP.
func NewClient(gateway net.IP) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: defaultTimeout,
}
}
// NewClientWithTimeout creates a new PCP client with a custom timeout.
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: timeout,
}
}
// SetLocalIP sets the local IP address to use in PCP requests.
func (c *Client) SetLocalIP(ip net.IP) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Debugf("invalid local IP: %v", ip)
}
c.mu.Lock()
c.localIP = addr.Unmap()
c.mu.Unlock()
}
// Gateway returns the gateway IP address.
func (c *Client) Gateway() net.IP {
return c.gateway.AsSlice()
}
// Announce sends a PCP ANNOUNCE request to discover PCP support.
// Returns the server's epoch time on success.
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
localIP, err := c.getLocalIP()
if err != nil {
return 0, fmt.Errorf("get local IP: %w", err)
}
req := buildAnnounceRequest(localIP)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return 0, fmt.Errorf("send announce: %w", err)
}
parsed, err := parseResponse(resp)
if err != nil {
return 0, fmt.Errorf("parse announce response: %w", err)
}
if parsed.ResultCode != ResultSuccess {
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
}
c.mu.Lock()
if c.updateEpochLocked(parsed.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.mu.Unlock()
return parsed.Epoch, nil
}
// AddPortMapping requests a port mapping from the PCP server.
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
}
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
// Use lifetime <= 0 to delete a mapping.
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
var extIP netip.Addr
if suggestedExtIP != nil {
var ok bool
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
if !ok {
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
}
extIP = extIP.Unmap()
}
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
}
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
localIP, err := c.getLocalIP()
if err != nil {
return nil, fmt.Errorf("get local IP: %w", err)
}
proto, err := protocolNumber(protocol)
if err != nil {
return nil, fmt.Errorf("parse protocol: %w", err)
}
var nonce [12]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("generate nonce: %w", err)
}
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
// default for positive durations that round to 0 seconds.
var lifetimeSec uint32
if lifetime > 0 {
lifetimeSec = uint32(lifetime.Seconds())
if lifetimeSec == 0 {
lifetimeSec = DefaultLifetime
}
}
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("send map request: %w", err)
}
mapResp, err := parseMapResponse(resp)
if err != nil {
return nil, fmt.Errorf("parse map response: %w", err)
}
if mapResp.Nonce != nonce {
return nil, fmt.Errorf("nonce mismatch in response")
}
if mapResp.Protocol != proto {
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
}
if mapResp.InternalPort != uint16(internalPort) {
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
}
if mapResp.ResultCode != ResultSuccess {
return nil, &Error{
Code: mapResp.ResultCode,
Message: ResultCodeString(mapResp.ResultCode),
}
}
c.mu.Lock()
if c.updateEpochLocked(mapResp.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.cacheExternalIPLocked(mapResp.ExternalIP)
c.mu.Unlock()
return mapResp, nil
}
// DeletePortMapping removes a port mapping by requesting zero lifetime.
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
var pcpErr *Error
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
return nil
}
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// GetExternalAddress returns the external IP address.
// First checks for a cached value from previous MAP responses.
// If not cached, creates a short-lived mapping to discover the external IP.
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
c.mu.Lock()
if c.externalIP.IsValid() {
ip := c.externalIP.AsSlice()
c.mu.Unlock()
return ip, nil
}
c.mu.Unlock()
// Use an ephemeral port in the dynamic range (49152-65535).
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
// Use minimal lifetime (1 second) for discovery.
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
if err != nil {
return nil, fmt.Errorf("create temporary mapping: %w", err)
}
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
log.Debugf("cleanup temporary PCP mapping: %v", err)
}
return resp.ExternalIP.AsSlice(), nil
}
// LastEpoch returns the last observed server epoch value.
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
func (c *Client) LastEpoch() uint32 {
c.mu.Lock()
defer c.mu.Unlock()
return c.lastEpoch
}
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
func (c *Client) EpochStateLost() bool {
c.mu.Lock()
defer c.mu.Unlock()
lost := c.epochStateLost
c.epochStateLost = false
return lost
}
// updateEpochLocked updates the epoch tracking and detects potential state loss.
// Returns true if state loss was detected (server likely restarted).
// Caller must hold c.mu.
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
now := time.Now()
stateLost := false
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
// client_delta = time since last response
// server_delta = epoch change since last response
// Invalid if: client_delta+2 < server_delta - server_delta/16
// OR: server_delta+2 < client_delta - client_delta/16
// The +2 handles quantization, /16 (6.25%) handles clock drift.
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
serverDelta := newEpoch - c.lastEpoch
// Check for epoch going backwards or jumping unexpectedly.
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
if clientDelta+2 < serverDelta-(serverDelta/16) ||
serverDelta+2 < clientDelta-(clientDelta/16) {
stateLost = true
c.epochStateLost = true
}
}
c.lastEpoch = newEpoch
c.epochTime = now
return stateLost
}
// cacheExternalIPLocked stores the external IP from a successful MAP response.
// Caller must hold c.mu.
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
if ip.IsValid() && !ip.IsUnspecified() {
c.externalIP = ip
}
}
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
var lastErr error
delay := initialRetryDelay
for range maxRetries {
resp, err := c.sendOnce(ctx, addr, req)
if err == nil {
return resp, nil
}
lastErr = err
if ctx.Err() != nil {
return nil, ctx.Err()
}
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
// RAND is random between -0.1 and +0.1
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(retryDelayWithJitter(delay)):
}
delay = min(delay*2, maxRetryDelay)
}
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
}
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
func retryDelayWithJitter(d time.Duration) time.Duration {
var b [1]byte
_, _ = rand.Read(b[:])
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
jitter := (float64(b[0])/255.0)*0.2 - 0.1
return time.Duration(float64(d) * (1 + jitter))
}
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("close UDP connection: %v", err)
}
}()
timeout := c.timeout
if deadline, ok := ctx.Deadline(); ok {
if remaining := time.Until(deadline); remaining < timeout {
timeout = remaining
}
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
if _, err := conn.WriteToUDP(req, addr); err != nil {
return nil, fmt.Errorf("write: %w", err)
}
resp := make([]byte, responseBufferSize)
n, from, err := conn.ReadFromUDP(resp)
if err != nil {
return nil, fmt.Errorf("read: %w", err)
}
// RFC 6887 §8.3: Validate response came from expected PCP server.
if !from.IP.Equal(addr.IP) {
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
}
return resp[:n], nil
}
func (c *Client) getLocalIP() (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.localIP.IsValid() {
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
}
return c.localIP, nil
}
func protocolNumber(protocol string) (uint8, error) {
switch protocol {
case "udp", "UDP":
return ProtoUDP, nil
case "tcp", "TCP":
return ProtoTCP, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}
// Error represents a PCP error response.
type Error struct {
Code uint8
Message string
}
func (e *Error) Error() string {
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
}
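The epoch sanity check in updateEpochLocked compares how much wall-clock time passed on the client against how much the server's epoch advanced, with a +2s quantization allowance and a 1/16 clock-drift tolerance on each side. A standalone sketch of the same predicate (names are illustrative):

```go
package main

import "fmt"

// epochValid applies the RFC 6887 §8.5 check: client_delta and server_delta
// must agree within +2 seconds of quantization error and 1/16 (6.25%) of
// clock drift. An invalid epoch suggests the server lost state (restarted).
func epochValid(clientDelta, serverDelta uint32) bool {
	if clientDelta+2 < serverDelta-serverDelta/16 {
		return false
	}
	if serverDelta+2 < clientDelta-clientDelta/16 {
		return false
	}
	return true
}

func main() {
	fmt.Println(epochValid(100, 100)) // deltas agree: valid
	fmt.Println(epochValid(100, 5))   // epoch barely moved in 100s: state loss
	fmt.Println(epochValid(5, 100))   // epoch jumped far ahead: state loss
}
```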


@@ -0,0 +1,187 @@
package pcp
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddrConversion(t *testing.T) {
tests := []struct {
name string
addr netip.Addr
}{
{"IPv4", netip.MustParseAddr("192.168.1.100")},
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
{"IPv6", netip.MustParseAddr("2001:db8::1")},
{"IPv6 loopback", netip.MustParseAddr("::1")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b16 := addrTo16(tt.addr)
recovered := addrFrom16(b16)
assert.Equal(t, tt.addr, recovered, "address should round-trip")
})
}
}
func TestBuildAnnounceRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
req := buildAnnounceRequest(clientIP)
require.Len(t, req, headerSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
// Check client IP is properly encoded as IPv4-mapped IPv6
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
assert.Equal(t, byte(192), req[20], "IP octet 1")
assert.Equal(t, byte(168), req[21], "IP octet 2")
assert.Equal(t, byte(1), req[22], "IP octet 3")
assert.Equal(t, byte(100), req[23], "IP octet 4")
}
func TestBuildMapRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
require.Len(t, req, mapRequestSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpMap), req[1], "opcode")
// Lifetime at bytes 4-7
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
// Nonce at bytes 24-35
assert.Equal(t, nonce[:], req[24:36], "nonce")
// Protocol at byte 36
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
// Internal port at bytes 40-41
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
// External port at bytes 42-43
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
}
func TestParseResponse(t *testing.T) {
// Construct a valid ANNOUNCE response
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce | OpReply
// Result code = 0 (success)
// Lifetime = 0
// Epoch = 12345
resp[8] = 0
resp[9] = 0
resp[10] = 0x30
resp[11] = 0x39
parsed, err := parseResponse(resp)
require.NoError(t, err)
assert.Equal(t, uint8(Version), parsed.Version)
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
assert.Equal(t, uint32(12345), parsed.Epoch)
}
func TestParseResponseErrors(t *testing.T) {
t.Run("too short", func(t *testing.T) {
_, err := parseResponse([]byte{1, 2, 3})
assert.Error(t, err)
})
t.Run("wrong version", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = 1 // Wrong version
resp[1] = OpReply
_, err := parseResponse(resp)
assert.Error(t, err)
})
t.Run("missing reply bit", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce // Missing OpReply bit
_, err := parseResponse(resp)
assert.Error(t, err)
})
}
func TestResultCodeString(t *testing.T) {
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
}
func TestProtocolNumber(t *testing.T) {
proto, err := protocolNumber("udp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
proto, err = protocolNumber("tcp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoTCP), proto)
proto, err = protocolNumber("UDP")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
_, err = protocolNumber("icmp")
assert.Error(t, err)
}
func TestClientCreation(t *testing.T) {
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
client := NewClient(gateway)
assert.Equal(t, net.IP(gateway), client.Gateway())
assert.Equal(t, defaultTimeout, client.timeout)
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
}
func TestNATType(t *testing.T) {
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
assert.Equal(t, "PCP", n.Type())
}
// Integration test - always skipped; edit the gateway/local IPs below and remove the Skip call to run manually.
func TestClientIntegration(t *testing.T) {
t.Skip("integration test - run manually against a real PCP gateway")
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
client := NewClient(gateway)
client.SetLocalIP(localIP)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Test ANNOUNCE
epoch, err := client.Announce(ctx)
require.NoError(t, err)
t.Logf("Server epoch: %d", epoch)
// Test MAP
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
require.NoError(t, err)
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
// Cleanup
err = client.DeletePortMapping(ctx, "udp", 51820)
require.NoError(t, err)
}


@@ -0,0 +1,209 @@
package pcp
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/libp2p/go-nat"
"github.com/libp2p/go-netroute"
)
var _ nat.NAT = (*NAT)(nil)
// NAT implements the go-nat NAT interface using PCP.
// Supports dual-stack (IPv4 and IPv6) when available.
// All methods are safe for concurrent use.
//
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
// and needs to be recreated with the new address.
type NAT struct {
client *Client
mu sync.RWMutex
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
client6 *Client
// localIP6 caches the local IPv6 address used for PCP requests.
localIP6 netip.Addr
}
// NewNAT creates a new NAT instance backed by PCP.
func NewNAT(gateway, localIP net.IP) *NAT {
client := NewClient(gateway)
client.SetLocalIP(localIP)
return &NAT{
client: client,
}
}
// Type returns "PCP" as the NAT type.
func (n *NAT) Type() string {
return "PCP"
}
// GetDeviceAddress returns the gateway IP address.
func (n *NAT) GetDeviceAddress() (net.IP, error) {
return n.client.Gateway(), nil
}
// GetExternalAddress returns the external IP address.
func (n *NAT) GetExternalAddress() (net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return n.client.GetExternalAddress(ctx)
}
// GetInternalAddress returns the local IP address used to communicate with the gateway.
func (n *NAT) GetInternalAddress() (net.IP, error) {
addr, err := n.client.getLocalIP()
if err != nil {
return nil, err
}
return addr.AsSlice(), nil
}
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
if err != nil {
return 0, fmt.Errorf("add mapping: %w", err)
}
n.mu.RLock()
client6 := n.client6
localIP6 := n.localIP6
n.mu.RUnlock()
if client6 == nil {
return int(resp.ExternalPort), nil
}
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
return int(resp.ExternalPort), nil
}
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
return int(resp.ExternalPort), nil
}
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
n.mu.RLock()
client6 := n.client6
n.mu.RUnlock()
if client6 != nil {
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
}
}
if err != nil {
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
epoch, err = n.client.Announce(ctx)
if err != nil {
return 0, false, fmt.Errorf("announce: %w", err)
}
return epoch, n.client.EpochStateLost(), nil
}
// DiscoverPCP attempts to discover a PCP-capable gateway.
// Returns a NAT interface if PCP is supported, or an error otherwise.
// Discovers both IPv4 and IPv6 gateways when available.
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
gateway, localIP, err := getDefaultGateway()
if err != nil {
return nil, fmt.Errorf("get default gateway: %w", err)
}
client := NewClient(gateway)
client.SetLocalIP(localIP)
if _, err := client.Announce(ctx); err != nil {
return nil, fmt.Errorf("PCP announce: %w", err)
}
result := &NAT{client: client}
discoverIPv6(ctx, result)
return result, nil
}
func discoverIPv6(ctx context.Context, result *NAT) {
gateway6, localIP6, err := getDefaultGateway6()
if err != nil {
log.Debugf("IPv6 gateway discovery failed: %v", err)
return
}
client6 := NewClient(gateway6)
client6.SetLocalIP(localIP6)
if _, err := client6.Announce(ctx); err != nil {
log.Debugf("PCP IPv6 announce failed: %v", err)
return
}
addr, ok := netip.AddrFromSlice(localIP6)
if !ok {
log.Debugf("invalid IPv6 local IP: %v", localIP6)
return
}
result.mu.Lock()
result.client6 = client6
result.localIP6 = addr
result.mu.Unlock()
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
}
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv4zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv6zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}


@@ -0,0 +1,225 @@
// Package pcp implements the Port Control Protocol (RFC 6887).
//
// # Implemented Features
//
// - ANNOUNCE opcode: Discovers PCP server support
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
// - Nonce validation: Prevents response spoofing
// - Epoch tracking: Detects server restarts per Section 8.5
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
//
// # Not Implemented
//
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
// - THIRD_PARTY option: For managing mappings on behalf of other devices
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
// - FILTER option: To restrict remote peer addresses
//
// These optional features are omitted because the primary use case is simple
// port forwarding for WireGuard, which only requires MAP with default behavior.
package pcp
import (
"encoding/binary"
"fmt"
"net/netip"
)
const (
// Version is the PCP protocol version (RFC 6887).
Version = 2
// Port is the standard PCP server port.
Port = 5351
// DefaultLifetime is the default requested mapping lifetime in seconds.
DefaultLifetime = 7200 // 2 hours
// Header sizes
headerSize = 24
mapPayloadSize = 36
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
)
// Opcodes
const (
OpAnnounce = 0
OpMap = 1
OpPeer = 2
OpReply = 0x80 // OR'd with opcode in responses
)
// Protocol numbers for MAP requests
const (
ProtoUDP = 17
ProtoTCP = 6
)
// Result codes (RFC 6887 Section 7.4)
const (
ResultSuccess = 0
ResultUnsuppVersion = 1
ResultNotAuthorized = 2
ResultMalformedRequest = 3
ResultUnsuppOpcode = 4
ResultUnsuppOption = 5
ResultMalformedOption = 6
ResultNetworkFailure = 7
ResultNoResources = 8
ResultUnsuppProtocol = 9
ResultUserExQuota = 10
ResultCannotProvideExt = 11
ResultAddressMismatch = 12
ResultExcessiveRemotePeers = 13
)
// ResultCodeString returns a human-readable string for a result code.
func ResultCodeString(code uint8) string {
switch code {
case ResultSuccess:
return "SUCCESS"
case ResultUnsuppVersion:
return "UNSUPP_VERSION"
case ResultNotAuthorized:
return "NOT_AUTHORIZED"
case ResultMalformedRequest:
return "MALFORMED_REQUEST"
case ResultUnsuppOpcode:
return "UNSUPP_OPCODE"
case ResultUnsuppOption:
return "UNSUPP_OPTION"
case ResultMalformedOption:
return "MALFORMED_OPTION"
case ResultNetworkFailure:
return "NETWORK_FAILURE"
case ResultNoResources:
return "NO_RESOURCES"
case ResultUnsuppProtocol:
return "UNSUPP_PROTOCOL"
case ResultUserExQuota:
return "USER_EX_QUOTA"
case ResultCannotProvideExt:
return "CANNOT_PROVIDE_EXTERNAL"
case ResultAddressMismatch:
return "ADDRESS_MISMATCH"
case ResultExcessiveRemotePeers:
return "EXCESSIVE_REMOTE_PEERS"
default:
return fmt.Sprintf("UNKNOWN(%d)", code)
}
}
// Response represents a parsed PCP response header.
type Response struct {
Version uint8
Opcode uint8
ResultCode uint8
Lifetime uint32
Epoch uint32
}
// MapResponse contains the full response to a MAP request.
type MapResponse struct {
Response
Nonce [12]byte
Protocol uint8
InternalPort uint16
ExternalPort uint16
ExternalIP netip.Addr
}
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
func addrTo16(addr netip.Addr) [16]byte {
if addr.Is4() {
return netip.AddrFrom4(addr.As4()).As16()
}
return addr.As16()
}
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
func addrFrom16(b [16]byte) netip.Addr {
return netip.AddrFrom16(b).Unmap()
}
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
func buildAnnounceRequest(clientIP netip.Addr) []byte {
req := make([]byte, headerSize)
req[0] = Version
req[1] = OpAnnounce
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
return req
}
// buildMapRequest creates a PCP MAP request packet.
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
req := make([]byte, mapRequestSize)
// Header
req[0] = Version
req[1] = OpMap
binary.BigEndian.PutUint32(req[4:8], lifetime)
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
// MAP payload
copy(req[24:36], nonce[:])
req[36] = protocol
binary.BigEndian.PutUint16(req[40:42], internalPort)
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
if suggestedExtIP.IsValid() {
extMapped := addrTo16(suggestedExtIP)
copy(req[44:60], extMapped[:])
}
return req
}
// parseResponse parses the common PCP response header.
func parseResponse(data []byte) (*Response, error) {
if len(data) < headerSize {
return nil, fmt.Errorf("response too short: %d bytes", len(data))
}
resp := &Response{
Version: data[0],
Opcode: data[1],
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
Lifetime: binary.BigEndian.Uint32(data[4:8]),
Epoch: binary.BigEndian.Uint32(data[8:12]),
}
if resp.Version != Version {
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
}
if resp.Opcode&OpReply == 0 {
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
}
return resp, nil
}
// parseMapResponse parses a complete MAP response.
func parseMapResponse(data []byte) (*MapResponse, error) {
if len(data) < mapRequestSize {
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
}
resp, err := parseResponse(data)
if err != nil {
return nil, fmt.Errorf("parse header: %w", err)
}
mapResp := &MapResponse{
Response: *resp,
Protocol: data[36],
InternalPort: binary.BigEndian.Uint16(data[40:42]),
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
ExternalIP: addrFrom16([16]byte(data[44:60])),
}
copy(mapResp.Nonce[:], data[24:36])
return mapResp, nil
}


@@ -0,0 +1,63 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
// discoverGateway is the function used for NAT gateway discovery.
// It can be replaced in tests to avoid real network operations.
// Tries PCP first, then falls back to NAT-PMP/UPnP.
var discoverGateway = defaultDiscoverGateway
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
pcpGateway, err := pcp.DiscoverPCP(ctx)
if err == nil {
return pcpGateway, nil
}
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
return nat.DiscoverGateway(ctx)
}
// State is persisted only for crash recovery cleanup
type State struct {
InternalPort uint16 `json:"internal_port,omitempty"`
Protocol string `json:"protocol,omitempty"`
}
func (s *State) Name() string {
return "port_forward_state"
}
// Cleanup implements statemanager.CleanableState for crash recovery
func (s *State) Cleanup() error {
if s.InternalPort == 0 {
return nil
}
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
defer cancel()
gateway, err := discoverGateway(ctx)
if err != nil {
// Discovery failure is not an error - gateway may not exist
log.Debugf("cleanup: no gateway found: %v", err)
return nil
}
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
return fmt.Errorf("delete port mapping: %w", err)
}
return nil
}


@@ -0,0 +1,10 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }


@@ -0,0 +1,241 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
if err := r.addScopedDefault(unspec, nexthop); err != nil {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks"
)
@@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{})
@@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
if nbnet.AdvancedRouting() {
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
if err != nil {
log.Errorf("Failed to get routes: %v", err)
return false, netip.Prefix{}
}

View File

@@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil

View File

@@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6
}
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,10 +48,6 @@ func EnableIPForwarding() error {
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,6 +25,9 @@ import (
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
)
var routeProtoFlag int
@@ -41,26 +44,42 @@ func init() {
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager)
}
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB()
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
}
flushedCount := 0
for _, msg := range msgs {
@@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix)
}
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return fmt.Errorf("build route message: %w", err)
}
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
@@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil
}
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
// kernel's matching reply, retrying transient failures until budget elapses.
// Callers do not need to manage sockets or seq numbers themselves.
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
}
}()
tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
continue
}
return backoff.Permanent(fmt.Errorf("read: %w", err))
}
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
continue
}
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
// uniquely identifies our own reply among broadcast traffic.
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
continue
}
if hdr.Errno != 0 {
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
}
return nil
}
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
}
@@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows
//go:build !linux && !windows && !darwin
package net

View File

@@ -1,24 +0,0 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build windows
//go:build (darwin && !ios) || windows
package net
@@ -24,17 +24,22 @@ func Init() {
}
func checkAdvancedRoutingSupport() bool {
legacyRouting := false
if val := os.Getenv(envUseLegacyRouting); val != "" {
parsed, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
} else {
legacyRouting = parsed
}
}
if legacyRouting {
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
return false
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !android
//go:build !linux && !windows && !darwin
package net

client/net/env_mobile.go Normal file
View File

@@ -0,0 +1,25 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,5 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows
//go:build !linux && !windows && !darwin
package net

client/net/net_darwin.go Normal file
View File

@@ -0,0 +1,160 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -4979,6 +4979,7 @@ type GetFeaturesResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"`
DisableNetworks bool `protobuf:"varint,3,opt,name=disable_networks,json=disableNetworks,proto3" json:"disable_networks,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -5027,6 +5028,13 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool {
return false
}
func (x *GetFeaturesResponse) GetDisableNetworks() bool {
if x != nil {
return x.DisableNetworks
}
return false
}
type TriggerUpdateRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -6472,10 +6480,11 @@ const file_daemon_proto_rawDesc = "" +
"\f_profileNameB\v\n" +
"\t_username\"\x10\n" +
"\x0eLogoutResponse\"\x14\n" +
"\x12GetFeaturesRequest\"x\n" +
"\x12GetFeaturesRequest\"\xa3\x01\n" +
"\x13GetFeaturesResponse\x12)\n" +
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" +
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" +
"\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" +
"\x14TriggerUpdateRequest\"M\n" +
"\x15TriggerUpdateResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +

View File

@@ -727,6 +727,7 @@ message GetFeaturesRequest{}
message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
bool disable_networks = 3;
}
message TriggerUpdateRequest {}

View File

@@ -9,6 +9,8 @@ import (
"strings"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
@@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
@@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
@@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -53,6 +53,7 @@ const (
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
errNetworksDisabled = "network selection is disabled by the administrator"
)
var ErrServiceNotUp = errors.New("service is not up")
@@ -88,6 +89,7 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
networksDisabled bool
sleepHandler *sleephandler.SleepHandler
@@ -104,7 +106,7 @@ type oauthAuthFlow struct {
}
// New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
s := &Server{
rootCtx: ctx,
logFile: logFile,
@@ -113,6 +115,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
networksDisabled: networksDisabled,
jwtCache: newJWTCache(),
}
agent := &serverAgent{s}
@@ -1628,6 +1631,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
DisableNetworks: s.networksDisabled,
}
return features, nil

View File

@@ -36,6 +36,7 @@ import (
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false)
s := New(ctx, "debug", "", false, false, false)
s.config = config
@@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false, false)
s := New(ctx, "console", "", false, false, false)
err = s.Start()
require.NoError(t, err)
@@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false, false)
s := New(ctx, "console", "", false, false, false)
err = s.Start()
require.NoError(t, err)
@@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}
@@ -329,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
ctx := context.Background()
s := New(ctx, "console", "", false, false)
s := New(ctx, "console", "", false, false, false)
rosenpassEnabled := true
rosenpassPermissive := true

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -138,10 +137,8 @@ func restoreResidualState(ctx context.Context, statePath string) error {
}
// clean up any remaining routes independently of the state file
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -187,24 +187,23 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostList := strings.Join(deduplicatedPatterns, ",")
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}

View File

@@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) {
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
}
func TestManager_MatchHostFormat(t *testing.T) {
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
peers := []PeerSSHInfo{
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
"should use Match host with comma-separated patterns")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable
t.Setenv(EnvForceSSHConfig, "true")

View File

@@ -2,7 +2,6 @@ package system
import (
"context"
"net"
"net/netip"
"strings"
@@ -145,59 +144,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v
}
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))

View File

@@ -2,12 +2,16 @@ package system
import (
"context"
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
@@ -15,11 +19,24 @@ func UpdateStaticInfoAsync() {
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
// Convert fixed-size byte arrays to Go strings
sysName := extractOsName(ctx, "sysName")
swVersion := extractOsVersion(ctx, "swVersion")
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
gio := &Info{
Kernel: sysName,
OSVersion: swVersion,
Platform: "unknown",
OS: sysName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: swVersion,
NetworkAddresses: addrs,
}
gio.Hostname = extractDeviceName(ctx, "hostname")
gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
@@ -27,6 +44,66 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// networkAddresses returns the list of network addresses on iOS.
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
// check that other platforms use and filter out link-local addresses as they
// are not useful for posture checks.
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: ""}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil

View File

@@ -0,0 +1,66 @@
//go:build !ios
package system
import (
"net"
"net/netip"
)
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
mac := iface.HardwareAddr.String()
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address, mac)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: mac}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}

View File

@@ -314,6 +314,7 @@ type serviceClient struct {
lastNotifiedVersion string
settingsEnabled bool
profilesEnabled bool
networksEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
@@ -368,6 +369,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
networksEnabled: true,
}
s.eventHandler = newEventHandler(s)
@@ -920,8 +922,10 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetIcon(s.icConnectedDot)
s.mUp.Disable()
s.mDown.Enable()
s.mNetworks.Enable()
s.mExitNode.Enable()
if s.networksEnabled {
s.mNetworks.Enable()
s.mExitNode.Enable()
}
s.startExitNodeRefresh()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
@@ -1093,14 +1097,14 @@ func (s *serviceClient) onTrayReady() {
s.getSrvConfig()
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
for {
// Check features before status so menus respect disable flags before being enabled
s.checkAndUpdateFeatures()
err := s.updateStatus()
if err != nil {
log.Errorf("error while updating status: %v", err)
}
// Check features periodically to handle daemon restarts
s.checkAndUpdateFeatures()
time.Sleep(2 * time.Second)
}
}()
@@ -1299,6 +1303,16 @@ func (s *serviceClient) checkAndUpdateFeatures() {
s.mProfile.setEnabled(profilesEnabled)
}
}
// Update networks and exit node menus based on current features
s.networksEnabled = features == nil || !features.DisableNetworks
if s.networksEnabled && s.connected {
s.mNetworks.Enable()
s.mExitNode.Enable()
} else {
s.mNetworks.Disable()
s.mExitNode.Disable()
}
}
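The diff above gates the Networks and Exit Node menus on two conditions: the daemon feature flag and the current connection state. Pulled out as a pure decision function, the rule is easy to test in isolation. A sketch under assumed names — `menuEnabled` and this `features` struct are illustrative, not the tray client's real types:

```go
package main

import "fmt"

// features mirrors the daemon flags the tray client inspects; only the
// field used by this gate is modeled here.
type features struct {
	DisableNetworks bool
}

// menuEnabled applies the rule from the diff: menus are enabled only when
// networks are not disabled (nil features means "all features enabled")
// and the client is connected.
func menuEnabled(f *features, connected bool) bool {
	networksEnabled := f == nil || !f.DisableNetworks
	return networksEnabled && connected
}

func main() {
	fmt.Println(menuEnabled(nil, true))                                // true
	fmt.Println(menuEnabled(&features{DisableNetworks: true}, true))   // false
	fmt.Println(menuEnabled(&features{DisableNetworks: false}, false)) // false
}
```

Checking features before status in the tray loop (as the reordering above does) means this decision is made from fresh flags before any menu is enabled.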
// getFeatures from the daemon to determine which features are enabled/disabled.

View File

@@ -119,6 +119,8 @@ server:
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.

go.mod
View File

@@ -17,12 +17,12 @@ require (
github.com/spf13/cobra v1.10.1
github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.48.0
golang.org/x/crypto v0.49.0
golang.org/x/sys v0.42.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.79.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.2.1
)
@@ -71,7 +71,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
@@ -115,13 +115,13 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
golang.org/x/mobile v0.0.0-20251113184115-a159579294ab
golang.org/x/mod v0.32.0
golang.org/x/net v0.51.0
golang.org/x/oauth2 v0.34.0
golang.org/x/sync v0.19.0
golang.org/x/term v0.40.0
golang.org/x/time v0.14.0
google.golang.org/api v0.257.0
golang.org/x/mod v0.33.0
golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.41.0
golang.org/x/time v0.15.0
google.golang.org/api v0.276.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.7
@@ -131,7 +131,7 @@ require (
)
require (
cloud.google.com/go/auth v0.17.0 // indirect
cloud.google.com/go/auth v0.20.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
dario.cat/mergo v1.0.1 // indirect
@@ -210,8 +210,8 @@ require (
github.com/google/btree v1.1.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
github.com/googleapis/gax-go/v2 v2.15.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
github.com/gorilla/handlers v1.5.2 // indirect
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect
@@ -295,16 +295,16 @@ require (
github.com/zeebo/blake3 v0.2.3 // indirect
go.mongodb.org/mongo-driver v1.17.9 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/image v0.33.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/tools v0.42.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
@@ -323,3 +323,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

go.sum
View File

@@ -1,5 +1,5 @@
cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4=
cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ=
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
@@ -285,10 +285,10 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ=
github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo=
github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc=
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
@@ -400,8 +400,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
@@ -449,12 +447,14 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
@@ -664,8 +664,8 @@ go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
@@ -707,8 +707,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ=
@@ -725,8 +725,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
@@ -745,11 +745,11 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -761,8 +761,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -811,8 +811,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -824,10 +824,10 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
@@ -839,8 +839,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -851,19 +851,19 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA=
google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY=
google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4=
google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=

View File

@@ -182,6 +182,23 @@ read_enable_proxy() {
return 0
}
read_enable_crowdsec() {
echo "" > /dev/stderr
echo "Do you want to enable CrowdSec IP reputation blocking?" > /dev/stderr
echo "CrowdSec checks client IPs against a community threat intelligence database" > /dev/stderr
echo "and blocks known malicious sources before they reach your services." > /dev/stderr
echo "A local CrowdSec LAPI container will be added to your deployment." > /dev/stderr
echo -n "Enable CrowdSec? [y/N]: " > /dev/stderr
read -r CHOICE < /dev/tty
if [[ "$CHOICE" =~ ^[Yy]$ ]]; then
echo "true"
else
echo "false"
fi
return 0
}
read_traefik_acme_email() {
echo "" > /dev/stderr
echo "Enter your email for Let's Encrypt certificate notifications." > /dev/stderr
@@ -297,6 +314,10 @@ initialize_default_values() {
# NetBird Proxy configuration
ENABLE_PROXY="false"
PROXY_TOKEN=""
# CrowdSec configuration
ENABLE_CROWDSEC="false"
CROWDSEC_BOUNCER_KEY=""
return 0
}
@@ -325,6 +346,9 @@ configure_reverse_proxy() {
if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
ENABLE_PROXY=$(read_enable_proxy)
if [[ "$ENABLE_PROXY" == "true" ]]; then
ENABLE_CROWDSEC=$(read_enable_crowdsec)
fi
fi
# Handle external Traefik-specific prompts (option 1)
@@ -354,7 +378,7 @@ check_existing_installation() {
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
echo "You can use the following commands:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt && rm -rf crowdsec/"
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
exit 1
fi
@@ -375,6 +399,9 @@ generate_configuration_files() {
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
# TCP ServersTransport for PROXY protocol v2 to the proxy backend
render_traefik_dynamic > traefik-dynamic.yaml
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
mkdir -p crowdsec
fi
fi
;;
1)
@@ -417,8 +444,12 @@ start_services_and_show_instructions() {
if [[ "$ENABLE_PROXY" == "true" ]]; then
# Phase 1: Start core services (without proxy)
local core_services="traefik dashboard netbird-server"
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
core_services="$core_services crowdsec"
fi
echo "Starting core services..."
$DOCKER_COMPOSE_COMMAND up -d traefik dashboard netbird-server
$DOCKER_COMPOSE_COMMAND up -d $core_services
sleep 3
wait_management_proxy traefik
@@ -438,7 +469,33 @@ start_services_and_show_instructions() {
echo "Proxy token created successfully."
# Generate proxy.env with the token
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
echo "Registering CrowdSec bouncer..."
local cs_retries=0
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
cs_retries=$((cs_retries + 1))
if [[ $cs_retries -ge 30 ]]; then
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
echo "You can register a bouncer manually later with:" > /dev/stderr
echo " docker exec netbird-crowdsec cscli bouncers add netbird-proxy -o raw" > /dev/stderr
ENABLE_CROWDSEC="false"
break
fi
sleep 2
done
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
CROWDSEC_BOUNCER_KEY=$($DOCKER_COMPOSE_COMMAND exec -T crowdsec \
cscli bouncers add netbird-proxy -o raw 2>/dev/null)
if [[ -z "$CROWDSEC_BOUNCER_KEY" ]]; then
echo "WARNING: Failed to create CrowdSec bouncer key. Skipping CrowdSec setup." > /dev/stderr
ENABLE_CROWDSEC="false"
else
echo "CrowdSec bouncer registered."
fi
fi
fi
render_proxy_env > proxy.env
# Start proxy service
@@ -525,11 +582,25 @@ render_docker_compose_traefik_builtin() {
# Generate proxy service section and Traefik dynamic config if enabled
local proxy_service=""
local proxy_volumes=""
local crowdsec_service=""
local crowdsec_volumes=""
local traefik_file_provider=""
local traefik_dynamic_volume=""
if [[ "$ENABLE_PROXY" == "true" ]]; then
traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"'
traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro"
local proxy_depends="
netbird-server:
condition: service_started"
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
proxy_depends="
netbird-server:
condition: service_started
crowdsec:
condition: service_healthy"
fi
proxy_service="
# NetBird Proxy - exposes internal resources to the internet
proxy:
@@ -539,8 +610,7 @@ render_docker_compose_traefik_builtin() {
- 51820:51820/udp
restart: unless-stopped
networks: [netbird]
depends_on:
- netbird-server
depends_on:${proxy_depends}
env_file:
- ./proxy.env
volumes:
@@ -563,6 +633,35 @@ render_docker_compose_traefik_builtin() {
"
proxy_volumes="
netbird_proxy_certs:"
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
crowdsec_service="
crowdsec:
image: crowdsecurity/crowdsec:v1.7.7
container_name: netbird-crowdsec
restart: unless-stopped
networks: [netbird]
environment:
COLLECTIONS: crowdsecurity/linux
volumes:
- ./crowdsec:/etc/crowdsec
- crowdsec_db:/var/lib/crowdsec/data
healthcheck:
test: ["CMD", "cscli", "lapi", "status"]
interval: 10s
timeout: 5s
retries: 15
labels:
- traefik.enable=false
logging:
driver: \"json-file\"
options:
max-size: \"500m\"
max-file: \"2\"
"
crowdsec_volumes="
crowdsec_db:"
fi
fi
cat <<EOF
@@ -675,10 +774,10 @@ $traefik_dynamic_volume
options:
max-size: "500m"
max-file: "2"
${proxy_service}
${proxy_service}${crowdsec_service}
volumes:
netbird_data:
netbird_traefik_letsencrypt:${proxy_volumes}
netbird_traefik_letsencrypt:${proxy_volumes}${crowdsec_volumes}
networks:
netbird:
@@ -783,6 +882,14 @@ NB_PROXY_PROXY_PROTOCOL=true
# Trust Traefik's IP for PROXY protocol headers
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
EOF
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
cat <<EOF
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
NB_PROXY_CROWDSEC_API_KEY=$CROWDSEC_BOUNCER_KEY
EOF
fi
return 0
}
@@ -1172,6 +1279,17 @@ print_builtin_traefik_instructions() {
echo ""
echo " *.$NETBIRD_DOMAIN CNAME $NETBIRD_DOMAIN"
echo ""
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
echo "CrowdSec IP Reputation:"
echo " CrowdSec LAPI is running and connected to the community blocklist."
echo " The proxy will automatically check client IPs against known threats."
echo " Enable CrowdSec per-service in the dashboard under Access Control."
echo ""
echo " To enroll in CrowdSec Console (optional, for dashboard and premium blocklists):"
echo " docker exec netbird-crowdsec cscli console enroll <your-enrollment-key>"
echo " Get your enrollment key at: https://app.crowdsec.net"
echo ""
fi
fi
return 0
}

View File

@@ -132,9 +132,18 @@ func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peer
return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
}
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
return c.OnPeerConnectedWithPeer(ctx, accountID, peer)
}
return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
// OnPeerConnectedWithPeer is the peer-object variant of OnPeerConnected. It
// skips the internal GetPeerByID and is intended for callers that already
// hold the peer (e.g. the Sync fast path). The accountID parameter is kept
// for symmetry with OnPeerConnected even though the peer object already
// carries it; callers typically have it handy from the surrounding context.
func (c *Controller) OnPeerConnectedWithPeer(ctx context.Context, accountID string, peer *nbpeer.Peer) (chan *network_map.UpdateMessage, error) {
_ = accountID
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
return c.peersUpdateManager.CreateChannel(ctx, peer.ID), nil
}
func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {

View File

@@ -35,6 +35,11 @@ type Controller interface {
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
// OnPeerConnectedWithPeer is equivalent to OnPeerConnected but accepts an
// already-fetched peer, skipping the internal GetPeerByID lookup. Intended
// for callers that have already resolved the peer (e.g. the Sync fast path)
// so the controller does not re-read what the caller just read.
OnPeerConnectedWithPeer(ctx context.Context, accountID string, peer *nbpeer.Peer) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
TrackEphemeralPeer(ctx context.Context, peer *nbpeer.Peer)

View File

@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: management/internals/controllers/network_map/interface.go
// Source: ./interface.go
//
// Generated by this command:
//
// mockgen -package network_map -destination=management/internals/controllers/network_map/interface_mock.go -source=management/internals/controllers/network_map/interface.go -build_flags=-mod=mod
// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
//
// Package network_map is a generated GoMock package.
@@ -145,6 +145,21 @@ func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID)
}
// OnPeerConnectedWithPeer mocks base method.
func (m *MockController) OnPeerConnectedWithPeer(ctx context.Context, accountID string, arg2 *peer.Peer) (chan *UpdateMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerConnectedWithPeer", ctx, accountID, arg2)
ret0, _ := ret[0].(chan *UpdateMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OnPeerConnectedWithPeer indicates an expected call of OnPeerConnectedWithPeer.
func (mr *MockControllerMockRecorder) OnPeerConnectedWithPeer(ctx, accountID, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnectedWithPeer", reflect.TypeOf((*MockController)(nil).OnPeerConnectedWithPeer), ctx, accountID, arg2)
}
// OnPeerDisconnected mocks base method.
func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) {
m.ctrl.T.Helper()

View File

@@ -7,6 +7,7 @@ import (
"testing"
"time"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -18,6 +19,7 @@ import (
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -29,6 +31,13 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
func testCacheStore(t *testing.T) cachestore.StoreInterface {
t.Helper()
s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
return s
}
func TestInitializeServiceForCreate(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
@@ -422,10 +431,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer {
t.Helper()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
return srv
}
@@ -703,10 +710,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
},
}
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
@@ -1128,10 +1133,8 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
@@ -26,6 +27,7 @@ import (
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/store"
@@ -58,6 +60,18 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics {
})
}
// CacheStore returns a shared cache store backed by Redis or in-memory depending on the environment.
// All consumers should reuse this store to avoid creating multiple Redis connections.
func (s *BaseServer) CacheStore() cachestore.StoreInterface {
return Create(s, func() cachestore.StoreInterface {
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultStoreMaxTimeout, nbcache.DefaultStoreCleanupInterval, nbcache.DefaultStoreMaxConn)
if err != nil {
log.Fatalf("failed to create shared cache store: %v", err)
}
return cs
})
}
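CacheStore leans on the server's `Create` helper to memoize the constructor so every consumer receives the same store instance. The memoization shape can be sketched with `sync.Once` (the `lazy` type below is an illustrative stand-in; `Create`'s actual implementation may differ):

```go
package main

import (
	"fmt"
	"sync"
)

// lazy memoizes a constructor so every caller gets the same instance,
// mirroring how BaseServer.CacheStore hands out one shared cache store
// instead of opening a new Redis connection per consumer.
type lazy[T any] struct {
	once  sync.Once
	value T
}

// get runs build at most once and returns the cached result thereafter.
func (l *lazy[T]) get(build func() T) T {
	l.once.Do(func() { l.value = build() })
	return l.value
}

func main() {
	var shared lazy[*map[string]string]
	builds := 0
	build := func() *map[string]string {
		builds++ // counts how many times the constructor actually ran
		m := map[string]string{}
		return &m
	}
	a := shared.get(build)
	b := shared.get(build)
	fmt.Println(a == b, builds) // same instance, built exactly once
}
```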
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
@@ -149,7 +163,9 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
peerSerialCache := nbgrpc.NewPeerSerialCache(context.Background(), s.CacheStore(), nbgrpc.DefaultPeerSerialCacheTTL)
fastPathFlag := nbgrpc.RunFastPathFlagRoutine(context.Background(), s.CacheStore(), nbgrpc.DefaultFastPathFlagInterval, nbgrpc.DefaultFastPathFlagKey)
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), peerSerialCache, fastPathFlag, s.CacheStore())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
@@ -195,10 +211,7 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create proxy token store: %v", err)
}
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), s.CacheStore())
log.Info("One-time token store initialized for proxy authentication")
return tokenStore
})
@@ -206,11 +219,7 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
return Create(s, func() *nbgrpc.PKCEVerifierStore {
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create PKCE verifier store: %v", err)
}
return pkceStore
return nbgrpc.NewPKCEVerifierStore(context.Background(), s.CacheStore())
})
}

View File

@@ -41,7 +41,8 @@ func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValida
context.Background(),
s.PeersManager(),
s.SettingsManager(),
s.EventStore())
s.EventStore(),
s.CacheStore())
if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err)
}

View File

@@ -100,7 +100,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy, s.CacheStore())
if err != nil {
log.Fatalf("failed to create account service: %v", err)
}

View File

@@ -0,0 +1,48 @@
// Package fastpathcache exposes the key prefixes and delete helpers for the
// Sync fast-path caches so mutation sites outside the gRPC server package
// can invalidate stale entries without a circular import on the grpc
// package that owns the read-side cache wrappers.
package fastpathcache
import (
"context"
"github.com/eko/gocache/lib/v4/cache"
cachestore "github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
)
const (
// ExtraSettingsKeyPrefix matches the prefix used by the read-side
// extraSettingsCache in management/internals/shared/grpc. Keep these in
// sync; drift would leak stale reads on mutations.
ExtraSettingsKeyPrefix = "extra-settings:"
// PeerGroupsKeyPrefix matches the prefix used by the read-side
// peerGroupsCache in management/internals/shared/grpc.
PeerGroupsKeyPrefix = "peer-groups:"
)
// InvalidateExtraSettings removes the cached ExtraSettings entry for the
// given account from the shared cache store. Safe to call with a nil store
// and safe to call when no entry exists. Errors are swallowed at debug level
// so mutation flows never fail because of a cache hiccup.
func InvalidateExtraSettings(ctx context.Context, store cachestore.StoreInterface, accountID string) {
if store == nil {
return
}
if err := cache.New[string](store).Delete(ctx, ExtraSettingsKeyPrefix+accountID); err != nil {
log.WithContext(ctx).Debugf("fastpathcache: invalidate extra settings for %s: %v", accountID, err)
}
}
// InvalidatePeerGroups removes the cached peer-groups entry for a peer. Safe
// to call with a nil store and safe to call when no entry exists.
func InvalidatePeerGroups(ctx context.Context, store cachestore.StoreInterface, peerID string) {
if store == nil {
return
}
if err := cache.New[string](store).Delete(ctx, PeerGroupsKeyPrefix+peerID); err != nil {
log.WithContext(ctx).Debugf("fastpathcache: invalidate peer groups for %s: %v", peerID, err)
}
}
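The contract of these helpers (nil store tolerated, missing key tolerated, errors swallowed so mutations never fail on a cache hiccup) can be sketched against a minimal in-memory store. `memStore` and `invalidateExtraSettings` below are illustrative stand-ins, not the package's real types, which sit on the shared gocache store:

```go
package main

import (
	"fmt"
	"sync"
)

// memStore is a minimal stand-in for the shared cache store.
type memStore struct {
	mu sync.Mutex
	m  map[string]string
}

func (s *memStore) Delete(key string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.m, key) // deleting a missing key is a no-op
	return nil
}

const extraSettingsKeyPrefix = "extra-settings:"

// invalidateExtraSettings mirrors the helper's contract: a nil store and a
// missing entry are both safe, and a delete error is only logged, never
// returned to the mutation flow.
func invalidateExtraSettings(store *memStore, accountID string) {
	if store == nil {
		return
	}
	if err := store.Delete(extraSettingsKeyPrefix + accountID); err != nil {
		fmt.Printf("invalidate extra settings for %s: %v\n", accountID, err)
	}
}

func main() {
	invalidateExtraSettings(nil, "acc1") // nil store: no-op, no panic
	s := &memStore{m: map[string]string{extraSettingsKeyPrefix + "acc1": "{}"}}
	invalidateExtraSettings(s, "acc1")
	fmt.Println(len(s.m)) // entry removed
}
```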

View File

@@ -0,0 +1,122 @@
package grpc
import (
"context"
"encoding/json"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/shared/fastpathcache"
nbtypes "github.com/netbirdio/netbird/management/server/types"
)
const (
extraSettingsCacheKeyPrefix = fastpathcache.ExtraSettingsKeyPrefix
peerGroupsCacheKeyPrefix = fastpathcache.PeerGroupsKeyPrefix
// DefaultExtraSettingsCacheTTL bounds how long a cached ExtraSettings
// blob survives. Settings rarely change, so a ~30s staleness window is
// cheap; after a change, expiry (or a client reconnect re-priming the
// entry) picks up the new value shortly after it lands.
DefaultExtraSettingsCacheTTL = 30 * time.Second
// DefaultPeerGroupsCacheTTL bounds how long a cached peer group set
// survives. Shorter than ExtraSettings because group membership changes
// have user-visible authz implications.
DefaultPeerGroupsCacheTTL = 15 * time.Second
)
// extraSettingsCache caches the ExtraSettings JSON per account so the fast
// path's buildFastPathResponse can skip GetExtraSettings on cache hit.
// TTL-based; staleness window is ~DefaultExtraSettingsCacheTTL.
type extraSettingsCache struct {
cache *cache.Cache[string]
ctx context.Context
ttl time.Duration
}
func newExtraSettingsCache(ctx context.Context, cacheStore store.StoreInterface, ttl time.Duration) *extraSettingsCache {
if cacheStore == nil {
return nil
}
return &extraSettingsCache{cache: cache.New[string](cacheStore), ctx: ctx, ttl: ttl}
}
func (c *extraSettingsCache) get(accountID string) (*nbtypes.ExtraSettings, bool) {
if c == nil {
return nil, false
}
raw, err := c.cache.Get(c.ctx, extraSettingsCacheKeyPrefix+accountID)
if err != nil {
return nil, false
}
var es nbtypes.ExtraSettings
if err := json.Unmarshal([]byte(raw), &es); err != nil {
log.Debugf("extra settings cache: unmarshal for %s: %v", accountID, err)
return nil, false
}
return &es, true
}
func (c *extraSettingsCache) set(accountID string, es *nbtypes.ExtraSettings) {
if c == nil || es == nil {
return
}
payload, err := json.Marshal(es)
if err != nil {
log.Debugf("extra settings cache: marshal for %s: %v", accountID, err)
return
}
if err := c.cache.Set(c.ctx, extraSettingsCacheKeyPrefix+accountID, string(payload), store.WithExpiration(c.ttl)); err != nil {
log.Debugf("extra settings cache: set for %s: %v", accountID, err)
}
}
// peerGroupsCache caches the list of group IDs a peer belongs to so the fast
// path's buildFastPathResponse can skip GetPeerGroupIDs on cache hit. The
// cache key includes the peerID; group membership changes propagate via TTL.
type peerGroupsCache struct {
cache *cache.Cache[string]
ctx context.Context
ttl time.Duration
}
func newPeerGroupsCache(ctx context.Context, cacheStore store.StoreInterface, ttl time.Duration) *peerGroupsCache {
if cacheStore == nil {
return nil
}
return &peerGroupsCache{cache: cache.New[string](cacheStore), ctx: ctx, ttl: ttl}
}
func (c *peerGroupsCache) get(peerID string) ([]string, bool) {
if c == nil {
return nil, false
}
raw, err := c.cache.Get(c.ctx, peerGroupsCacheKeyPrefix+peerID)
if err != nil {
return nil, false
}
var ids []string
if err := json.Unmarshal([]byte(raw), &ids); err != nil {
log.Debugf("peer groups cache: unmarshal for %s: %v", peerID, err)
return nil, false
}
return ids, true
}
func (c *peerGroupsCache) set(peerID string, ids []string) {
if c == nil {
return
}
payload, err := json.Marshal(ids)
if err != nil {
log.Debugf("peer groups cache: marshal for %s: %v", peerID, err)
return
}
if err := c.cache.Set(c.ctx, peerGroupsCacheKeyPrefix+peerID, string(payload), store.WithExpiration(c.ttl)); err != nil {
log.Debugf("peer groups cache: set for %s: %v", peerID, err)
}
}

View File

@@ -0,0 +1,131 @@
package grpc
import (
"context"
"errors"
"strings"
"sync/atomic"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
)
const (
// DefaultFastPathFlagInterval is the default poll interval for the Sync
// fast-path feature flag. Kept lower than the log-level overrider because
// operators will want the toggle to propagate quickly during rollout.
DefaultFastPathFlagInterval = 1 * time.Minute
// DefaultFastPathFlagKey is the cache key polled by RunFastPathFlagRoutine
// when the caller does not provide an override.
DefaultFastPathFlagKey = "peerSyncFastPath"
)
// FastPathFlag exposes the current on/off state of the Sync fast path. The
// zero value and a nil receiver both report disabled, so callers can always
// treat the flag as a non-nil gate without an additional nil check.
type FastPathFlag struct {
enabled atomic.Bool
}
// NewFastPathFlag returns a FastPathFlag whose state is set to the given
// value. Callers that need the runtime toggle should use
// RunFastPathFlagRoutine instead; this constructor is meant for tests and
// for consumers that want to force the flag on or off.
func NewFastPathFlag(enabled bool) *FastPathFlag {
f := &FastPathFlag{}
f.setEnabled(enabled)
return f
}
// Enabled reports whether the Sync fast path is currently enabled for this
// replica. A nil receiver reports false so a disabled build or test can pass
// a nil flag and skip the fast path entirely.
func (f *FastPathFlag) Enabled() bool {
if f == nil {
return false
}
return f.enabled.Load()
}
func (f *FastPathFlag) setEnabled(v bool) {
if f == nil {
return
}
f.enabled.Store(v)
}
// RunFastPathFlagRoutine starts a background goroutine that polls the shared
// cache store for the Sync fast-path feature flag and updates the returned
// FastPathFlag accordingly. When cacheStore is nil the routine returns a
// handle that stays permanently disabled, so every Sync falls back to the
// full network map path.
//
// The shared store is Redis-backed when NB_CACHE_REDIS_ADDRESS is set (so the
// flag is toggled cluster-wide by writing the key in Redis) and falls back to
// an in-process gocache otherwise, which is enough for single-replica dev and
// test setups.
//
// The routine fails closed: any store read error (other than a plain "key not
// found" miss) disables the flag until Redis confirms it is enabled again.
func RunFastPathFlagRoutine(ctx context.Context, cacheStore store.StoreInterface, interval time.Duration, flagKey string) *FastPathFlag {
flag := &FastPathFlag{}
if cacheStore == nil {
log.Infof("Shared cache store not provided. Sync fast path disabled")
return flag
}
if flagKey == "" {
flagKey = DefaultFastPathFlagKey
}
flagCache := cache.New[string](cacheStore)
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
refresh := func() {
getCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
value, err := flagCache.Get(getCtx, flagKey)
if err != nil {
var notFound *store.NotFound
if !errors.As(err, &notFound) {
log.Errorf("Sync fast-path flag refresh: %v; disabling fast path", err)
}
flag.setEnabled(false)
return
}
flag.setEnabled(parseFastPathFlag(value))
}
refresh()
for {
select {
case <-ctx.Done():
log.Infof("Stopping Sync fast-path flag routine")
return
case <-ticker.C:
refresh()
}
}
}()
return flag
}
// parseFastPathFlag accepts "1" or "true" (any casing, surrounding whitespace
// tolerated) as enabled and treats every other value as disabled.
func parseFastPathFlag(value string) bool {
v := strings.TrimSpace(value)
if v == "1" {
return true
}
return strings.EqualFold(v, "true")
}

View File

@@ -0,0 +1,176 @@
package grpc
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/eko/gocache/lib/v4/store"
gocache_store "github.com/eko/gocache/store/go_cache/v4"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseFastPathFlag(t *testing.T) {
tests := []struct {
name string
value string
want bool
}{
{"one", "1", true},
{"true lowercase", "true", true},
{"true uppercase", "TRUE", true},
{"true mixed case", "True", true},
{"true with whitespace", " true ", true},
{"zero", "0", false},
{"false", "false", false},
{"empty", "", false},
{"yes", "yes", false},
{"garbage", "garbage", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, parseFastPathFlag(tt.value), "parseFastPathFlag(%q)", tt.value)
})
}
}
func TestFastPathFlag_EnabledDefaultsFalse(t *testing.T) {
flag := &FastPathFlag{}
assert.False(t, flag.Enabled(), "zero value flag should report disabled")
}
func TestFastPathFlag_NilSafeEnabled(t *testing.T) {
var flag *FastPathFlag
assert.False(t, flag.Enabled(), "nil flag should report disabled without panicking")
}
func TestFastPathFlag_SetEnabled(t *testing.T) {
flag := &FastPathFlag{}
flag.setEnabled(true)
assert.True(t, flag.Enabled(), "flag should report enabled after setEnabled(true)")
flag.setEnabled(false)
assert.False(t, flag.Enabled(), "flag should report disabled after setEnabled(false)")
}
func TestNewFastPathFlag(t *testing.T) {
assert.True(t, NewFastPathFlag(true).Enabled(), "NewFastPathFlag(true) should report enabled")
assert.False(t, NewFastPathFlag(false).Enabled(), "NewFastPathFlag(false) should report disabled")
}
func TestRunFastPathFlagRoutine_NilStoreStaysDisabled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
flag := RunFastPathFlagRoutine(ctx, nil, 50*time.Millisecond, "peerSyncFastPath")
require.NotNil(t, flag, "RunFastPathFlagRoutine should always return a non-nil flag")
assert.False(t, flag.Enabled(), "flag should stay disabled when no cache store is provided")
time.Sleep(150 * time.Millisecond)
assert.False(t, flag.Enabled(), "flag should remain disabled after wait when no cache store is provided")
}
func TestRunFastPathFlagRoutine_ReadsFlagFromStore(t *testing.T) {
cacheStore := newFastPathTestStore(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
flag := RunFastPathFlagRoutine(ctx, cacheStore, 50*time.Millisecond, "peerSyncFastPath")
require.NotNil(t, flag)
assert.False(t, flag.Enabled(), "flag should start disabled when the key is missing")
require.NoError(t, cacheStore.Set(ctx, "peerSyncFastPath", "1"), "seed flag=1 into shared store")
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should flip enabled after the key is set to 1")
require.NoError(t, cacheStore.Set(ctx, "peerSyncFastPath", "0"), "override flag=0 into shared store")
assert.Eventually(t, func() bool {
return !flag.Enabled()
}, 2*time.Second, 25*time.Millisecond, "flag should flip disabled after the key is set to 0")
require.NoError(t, cacheStore.Delete(ctx, "peerSyncFastPath"), "remove flag key")
assert.Eventually(t, func() bool {
return !flag.Enabled()
}, 2*time.Second, 25*time.Millisecond, "flag should stay disabled after the key is deleted")
require.NoError(t, cacheStore.Set(ctx, "peerSyncFastPath", "true"), "enable via string true")
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should accept \"true\" as enabled")
}
func TestRunFastPathFlagRoutine_MissingKeyKeepsDisabled(t *testing.T) {
cacheStore := newFastPathTestStore(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
flag := RunFastPathFlagRoutine(ctx, cacheStore, 50*time.Millisecond, "peerSyncFastPathAbsent")
require.NotNil(t, flag)
time.Sleep(200 * time.Millisecond)
assert.False(t, flag.Enabled(), "flag should stay disabled when the key is missing from the store")
}
func TestRunFastPathFlagRoutine_DefaultKeyUsedWhenEmpty(t *testing.T) {
cacheStore := newFastPathTestStore(t)
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
require.NoError(t, cacheStore.Set(ctx, DefaultFastPathFlagKey, "1"), "seed default key")
flag := RunFastPathFlagRoutine(ctx, cacheStore, 50*time.Millisecond, "")
require.NotNil(t, flag)
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "empty flagKey should fall back to DefaultFastPathFlagKey")
}
func newFastPathTestStore(t *testing.T) store.StoreInterface {
t.Helper()
return gocache_store.NewGoCache(gocache.New(5*time.Minute, 10*time.Minute))
}
func TestRunFastPathFlagRoutine_FailsClosedOnReadError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
s := &flakyStore{
StoreInterface: newFastPathTestStore(t),
}
require.NoError(t, s.Set(ctx, "peerSyncFastPath", "1"), "seed flag enabled")
flag := RunFastPathFlagRoutine(ctx, s, 50*time.Millisecond, "peerSyncFastPath")
require.NotNil(t, flag)
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should flip enabled while store reads succeed")
s.setGetError(errors.New("simulated transient store failure"))
assert.Eventually(t, func() bool {
return !flag.Enabled()
}, 2*time.Second, 25*time.Millisecond, "flag should flip disabled on store read error (fail-closed)")
s.setGetError(nil)
assert.Eventually(t, flag.Enabled, 2*time.Second, 25*time.Millisecond, "flag should recover once the store read succeeds again")
}
// flakyStore wraps a real store and lets tests inject a transient Get error
// without affecting Set/Delete. Used to exercise fail-closed behaviour.
type flakyStore struct {
store.StoreInterface
getErr atomic.Pointer[error]
}
func (f *flakyStore) Get(ctx context.Context, key any) (any, error) {
if errPtr := f.getErr.Load(); errPtr != nil && *errPtr != nil {
return nil, *errPtr
}
return f.StoreInterface.Get(ctx, key)
}
func (f *flakyStore) setGetError(err error) {
if err == nil {
f.getErr.Store(nil)
return
}
f.getErr.Store(&err)
}

View File

@@ -14,8 +14,6 @@ import (
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
type tokenMetadata struct {
@@ -32,17 +30,12 @@ type OneTimeTokenStore struct {
ctx context.Context
}
// NewOneTimeTokenStore creates a token store with automatic backend selection
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}
// NewOneTimeTokenStore creates a token store using the provided shared cache store.
func NewOneTimeTokenStore(ctx context.Context, cacheStore store.StoreInterface) *OneTimeTokenStore {
return &OneTimeTokenStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}
}
// GenerateToken creates a new cryptographically secure one-time token

View File

@@ -0,0 +1,82 @@
package grpc
import (
"context"
"encoding/json"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
)
const (
peerSerialCacheKeyPrefix = "peer-sync:"
// DefaultPeerSerialCacheTTL bounds how long a cached serial survives. If the
// cache write on a full-map send ever drops, entries naturally expire and
// the next Sync falls back to the full path, re-priming the cache.
DefaultPeerSerialCacheTTL = 24 * time.Hour
)
// PeerSerialCache records the NetworkMap serial and meta hash last delivered to
// each peer on Sync. Lookups are used to skip full network map computation when
// the peer already has the latest state. Backed by the shared cache store, so
// entries are visible across management replicas that share a Redis instance.
type PeerSerialCache struct {
cache *cache.Cache[string]
ctx context.Context
ttl time.Duration
}
// NewPeerSerialCache creates a cache wrapper bound to the shared cache store.
// The ttl is applied to every Set call; entries older than ttl are treated as
// misses so the server eventually converges to delivering a full map even if
// an earlier Set was lost.
func NewPeerSerialCache(ctx context.Context, cacheStore store.StoreInterface, ttl time.Duration) *PeerSerialCache {
return &PeerSerialCache{
cache: cache.New[string](cacheStore),
ctx: ctx,
ttl: ttl,
}
}
// Get returns the entry previously recorded for this peer and whether a valid
// entry was found. A cache miss or any read error is reported as a miss so
// callers fall back to the full map path.
func (c *PeerSerialCache) Get(pubKey string) (peerSyncEntry, bool) {
raw, err := c.cache.Get(c.ctx, peerSerialCacheKeyPrefix+pubKey)
if err != nil {
return peerSyncEntry{}, false
}
entry := peerSyncEntry{}
if err := json.Unmarshal([]byte(raw), &entry); err != nil {
log.Debugf("peer serial cache: unmarshal entry for %s: %v", pubKey, err)
return peerSyncEntry{}, false
}
return entry, true
}
// Set records what the server most recently delivered to this peer. Errors are
// logged at debug level so cache outages degrade gracefully into the full map
// path on the next Sync rather than failing the current Sync.
func (c *PeerSerialCache) Set(pubKey string, entry peerSyncEntry) {
payload, err := json.Marshal(entry)
if err != nil {
log.Debugf("peer serial cache: marshal entry for %s: %v", pubKey, err)
return
}
if err := c.cache.Set(c.ctx, peerSerialCacheKeyPrefix+pubKey, string(payload), store.WithExpiration(c.ttl)); err != nil {
log.Debugf("peer serial cache: set entry for %s: %v", pubKey, err)
}
}
// Delete removes any cached entry for this peer. Used on Login so the next
// Sync always sees a miss and delivers a full map.
func (c *PeerSerialCache) Delete(pubKey string) {
if err := c.cache.Delete(c.ctx, peerSerialCacheKeyPrefix+pubKey); err != nil {
log.Debugf("peer serial cache: delete entry for %s: %v", pubKey, err)
}
}

View File

@@ -0,0 +1,116 @@
package grpc
import "testing"
func TestShouldSkipNetworkMap(t *testing.T) {
tests := []struct {
name string
goOS string
hit bool
cached peerSyncEntry
currentSerial uint64
incomingMeta uint64
want bool
}{
{
name: "android never skips even on clean cache hit",
goOS: "android",
hit: true,
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
currentSerial: 42,
incomingMeta: 7,
want: false,
},
{
name: "android uppercase never skips",
goOS: "Android",
hit: true,
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
currentSerial: 42,
incomingMeta: 7,
want: false,
},
{
name: "cache miss forces full path",
goOS: "linux",
hit: false,
cached: peerSyncEntry{},
currentSerial: 42,
incomingMeta: 7,
want: false,
},
{
name: "serial mismatch forces full path",
goOS: "linux",
hit: true,
cached: peerSyncEntry{Serial: 41, MetaHash: 7},
currentSerial: 42,
incomingMeta: 7,
want: false,
},
{
name: "meta mismatch forces full path",
goOS: "linux",
hit: true,
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
currentSerial: 42,
incomingMeta: 9,
want: false,
},
{
name: "clean hit on linux skips",
goOS: "linux",
hit: true,
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
currentSerial: 42,
incomingMeta: 7,
want: true,
},
{
name: "clean hit on darwin skips",
goOS: "darwin",
hit: true,
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
currentSerial: 42,
incomingMeta: 7,
want: true,
},
{
name: "clean hit on windows skips",
goOS: "windows",
hit: true,
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
currentSerial: 42,
incomingMeta: 7,
want: true,
},
{
name: "zero current serial never skips",
goOS: "linux",
hit: true,
cached: peerSyncEntry{Serial: 0, MetaHash: 7},
currentSerial: 0,
incomingMeta: 7,
want: false,
},
{
name: "empty goos treated as non-android and skips",
goOS: "",
hit: true,
cached: peerSyncEntry{Serial: 42, MetaHash: 7},
currentSerial: 42,
incomingMeta: 7,
want: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := shouldSkipNetworkMap(tc.goOS, tc.hit, tc.cached, tc.currentSerial, tc.incomingMeta)
if got != tc.want {
t.Fatalf("shouldSkipNetworkMap(%q, hit=%v, cached=%+v, current=%d, meta=%d) = %v, want %v",
tc.goOS, tc.hit, tc.cached, tc.currentSerial, tc.incomingMeta, got, tc.want)
}
})
}
}

View File

@@ -0,0 +1,134 @@
package grpc
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
func newTestPeerSerialCache(t *testing.T, ttl, cleanup time.Duration) *PeerSerialCache {
t.Helper()
s, err := nbcache.NewStore(context.Background(), ttl, cleanup, 100)
require.NoError(t, err, "cache store must initialise")
return NewPeerSerialCache(context.Background(), s, ttl)
}
func TestPeerSerialCache_GetSetDelete(t *testing.T) {
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
key := "pubkey-aaa"
_, hit := c.Get(key)
assert.False(t, hit, "empty cache must miss")
c.Set(key, peerSyncEntry{Serial: 42, MetaHash: 7})
entry, hit := c.Get(key)
require.True(t, hit, "after Set, Get must hit")
assert.Equal(t, uint64(42), entry.Serial, "serial roundtrip")
assert.Equal(t, uint64(7), entry.MetaHash, "meta hash roundtrip")
c.Delete(key)
_, hit = c.Get(key)
assert.False(t, hit, "after Delete, Get must miss")
}
func TestPeerSerialCache_GetMissReturnsZero(t *testing.T) {
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
entry, hit := c.Get("missing")
assert.False(t, hit, "miss must report false")
assert.Equal(t, peerSyncEntry{}, entry, "miss must return zero value")
}
func TestPeerSerialCache_TTLExpiry(t *testing.T) {
c := newTestPeerSerialCache(t, 100*time.Millisecond, 10*time.Millisecond)
key := "pubkey-ttl"
c.Set(key, peerSyncEntry{Serial: 1, MetaHash: 1})
time.Sleep(250 * time.Millisecond)
_, hit := c.Get(key)
assert.False(t, hit, "entry must expire after TTL")
}
func TestPeerSerialCache_OverwriteUpdatesValue(t *testing.T) {
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
key := "pubkey-overwrite"
c.Set(key, peerSyncEntry{Serial: 1, MetaHash: 1})
c.Set(key, peerSyncEntry{Serial: 99, MetaHash: 123})
entry, hit := c.Get(key)
require.True(t, hit, "overwritten key must still be present")
assert.Equal(t, uint64(99), entry.Serial, "overwrite updates serial")
assert.Equal(t, uint64(123), entry.MetaHash, "overwrite updates meta hash")
}
func TestPeerSerialCache_IsolatedPerKey(t *testing.T) {
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
c.Set("a", peerSyncEntry{Serial: 1, MetaHash: 1})
c.Set("b", peerSyncEntry{Serial: 2, MetaHash: 2})
a, hitA := c.Get("a")
b, hitB := c.Get("b")
require.True(t, hitA, "key a must hit")
require.True(t, hitB, "key b must hit")
assert.Equal(t, uint64(1), a.Serial, "key a serial")
assert.Equal(t, uint64(2), b.Serial, "key b serial")
c.Delete("a")
_, hitA = c.Get("a")
_, hitB = c.Get("b")
assert.False(t, hitA, "deleting a must not affect b")
assert.True(t, hitB, "b must remain after a delete")
}
func TestPeerSerialCache_Concurrent(t *testing.T) {
c := newTestPeerSerialCache(t, time.Minute, time.Minute)
var wg sync.WaitGroup
const workers = 50
const iterations = 20
for w := 0; w < workers; w++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := "pubkey"
for i := 0; i < iterations; i++ {
c.Set(key, peerSyncEntry{Serial: uint64(id*iterations + i), MetaHash: uint64(id)})
_, _ = c.Get(key)
}
}(w)
}
wg.Wait()
_, hit := c.Get("pubkey")
assert.True(t, hit, "cache must survive concurrent Set/Get without deadlock")
}
func TestPeerSerialCache_Redis(t *testing.T) {
if os.Getenv(nbcache.RedisStoreEnvVar) == "" {
t.Skipf("set %s to run this test against a real Redis", nbcache.RedisStoreEnvVar)
}
s, err := nbcache.NewStore(context.Background(), time.Minute, 10*time.Second, 10)
require.NoError(t, err, "redis store must initialise")
c := NewPeerSerialCache(context.Background(), s, time.Minute)
key := "redis-pubkey"
c.Set(key, peerSyncEntry{Serial: 42, MetaHash: 7})
entry, hit := c.Get(key)
require.True(t, hit, "redis hit expected")
assert.Equal(t, uint64(42), entry.Serial)
c.Delete(key)
}

View File

@@ -8,8 +8,6 @@ import (
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
@@ -19,17 +17,12 @@ type PKCEVerifierStore struct {
ctx context.Context
}
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}
// NewPKCEVerifierStore creates a PKCE verifier store using the provided shared cache store.
func NewPKCEVerifierStore(ctx context.Context, cacheStore store.StoreInterface) *PKCEVerifierStore {
return &PKCEVerifierStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}
}
// Store saves a PKCE verifier associated with an OAuth state parameter.

View File

@@ -9,13 +9,22 @@ import (
"testing"
"time"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/shared/management/proto"
)
func testCacheStore(t *testing.T) cachestore.StoreInterface {
t.Helper()
s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
return s
}
type testProxyController struct {
mu sync.Mutex
clusterProxies map[string]map[string]struct{}
@@ -114,11 +123,8 @@ func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool {
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
s := &ProxyServiceServer{
tokenStore: tokenStore,
@@ -174,11 +180,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
s := &ProxyServiceServer{
tokenStore: tokenStore,
@@ -211,11 +214,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
ctx := context.Background()
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
s := &ProxyServiceServer{
tokenStore: tokenStore,
@@ -267,8 +267,7 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
func TestOAuthState_NeverTheSame(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
@@ -296,8 +295,7 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
@@ -307,7 +305,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
}
// Old format had only 2 parts: base64(url)|hmac
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
err := s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err = s.ValidateState("base64url|hmac")
@@ -317,8 +315,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
@@ -328,7 +325,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
}
// Store with tampered HMAC
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
err := s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
require.NoError(t, err)
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
@@ -337,8 +334,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
}
func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t))
s := &ProxyServiceServer{
tokenStore: tokenStore,
@@ -410,8 +406,7 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
}
func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t))
s := &ProxyServiceServer{
tokenStore: tokenStore,
@@ -442,8 +437,7 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
// scenario for an existing service, verifying the correct update types
// reach the correct clusters.
func TestServiceModifyNotifications(t *testing.T) {
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t))
newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) {
s := &ProxyServiceServer{

View File

@@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"
cachestore "github.com/eko/gocache/lib/v4/store"
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
@@ -84,9 +85,35 @@ type Server struct {
reverseProxyManager rpservice.Manager
reverseProxyMu sync.RWMutex
// peerSerialCache lets Sync skip full network map computation when the peer
// already has the latest account serial. A nil cache disables the fast path.
peerSerialCache *PeerSerialCache
// fastPathFlag is the runtime kill switch for the Sync fast path. A nil
// flag or a flag reporting disabled forces every Sync through the full
// network map path.
fastPathFlag *FastPathFlag
// Secondary TTL-based caches used by the Sync fast path to skip DB reads
// for the account's ExtraSettings and a peer's group membership. Both
// are nil-safe and disabled if the shared cache store wasn't provided.
extraSettingsCache *extraSettingsCache
peerGroupsCache *peerGroupsCache
// inflightMarkPeerConnected dedupes the fire-and-forget MarkPeerConnected
// writes kicked off by the fast path. Keys are peer pubkeys; presence
// means a background goroutine is already writing for that peer, so
// concurrent fast-path Syncs for the same peer coalesce to one write.
inflightMarkPeerConnected sync.Map
}
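The comments above promise that a nil `fastPathFlag` or nil caches disable the fast path without crashing. In Go that usually rests on nil-receiver methods; a minimal sketch (a hypothetical `FastPathFlag` shape assuming an atomic bool — the real type is defined elsewhere in this PR) illustrates the pattern:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// FastPathFlag sketch: calling Enabled on a nil *FastPathFlag reports
// disabled, so callers never need a separate nil check before the call.
// (Hypothetical shape; the actual field layout may differ.)
type FastPathFlag struct {
	enabled atomic.Bool
}

// Enabled is safe to call on a nil receiver: a nil flag means the kill
// switch was never wired up, which must read as "fast path off".
func (f *FastPathFlag) Enabled() bool {
	if f == nil {
		return false
	}
	return f.enabled.Load()
}

func main() {
	var flag *FastPathFlag
	fmt.Println(flag.Enabled()) // nil flag: every Sync takes the full path

	flag = &FastPathFlag{}
	flag.enabled.Store(true)
	fmt.Println(flag.Enabled())
}
```

This keeps call sites like `s.fastPathFlag.Enabled()` valid whether or not `NewServer` received a flag.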
// NewServer creates a new Management server
// NewServer creates a new Management server. peerSerialCache and fastPathFlag
// are both optional; when either is nil or the flag reports disabled, the
// Sync fast path is disabled and every request runs the full map computation,
// matching the pre-cache behaviour. cacheStore is used to back the
// secondary fast-path caches (account serial, ExtraSettings, peer groups);
// a nil store silently disables those caches without affecting correctness.
func NewServer(
config *nbconfig.Config,
accountManager account.Manager,
@@ -98,6 +125,9 @@ func NewServer(
integratedPeerValidator integrated_validator.IntegratedValidator,
networkMapController network_map.Controller,
oAuthConfigProvider idp.OAuthConfigProvider,
peerSerialCache *PeerSerialCache,
fastPathFlag *FastPathFlag,
cacheStore cachestore.StoreInterface,
) (*Server, error) {
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
@@ -145,6 +175,12 @@ func NewServer(
syncLim: syncLim,
syncLimEnabled: syncLimEnabled,
peerSerialCache: peerSerialCache,
fastPathFlag: fastPathFlag,
extraSettingsCache: newExtraSettingsCache(context.Background(), cacheStore, DefaultExtraSettingsCacheTTL),
peerGroupsCache: newPeerGroupsCache(context.Background(), cacheStore, DefaultPeerGroupsCacheTTL),
}, nil
}
@@ -233,24 +269,50 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
ctx := srv.Context()
syncReq := &proto.SyncRequest{}
parseStart := time.Now()
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
s.syncSem.Add(-1)
return err
}
log.WithContext(ctx).Debugf("fast path: parseRequest took %s", time.Since(parseStart))
realIP := getRealIP(ctx)
sRealIP := realIP.String()
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
userID, err := s.accountManager.GetUserIDByPeerKey(ctx, peerKey.String())
if err != nil {
s.syncSem.Add(-1)
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return status.Errorf(codes.PermissionDenied, "peer is not registered")
metahashed := metaHash(peerMeta, sRealIP)
// Fast path authorisation short-circuit: if the peer-sync cache has a
// complete entry whose metaHash still matches the incoming request, we can
// skip GetPeerAuthInfo entirely. The entry carries AccountID and HasUser
// so we have everything the loginFilter gate and the rest of the handler
// need. On any mismatch we fall back to the DB read below.
var (
userID string
accountID string
)
cachedEntry, cachedEntryHit := s.lookupPeerAuthFromCache(peerKey.String(), metahashed, peerMeta.GoOS)
if cachedEntryHit {
accountID = cachedEntry.AccountID
if cachedEntry.HasUser {
userID = "cached"
}
return mapError(ctx, err)
log.WithContext(ctx).Debugf("fast path: GetPeerAuthInfo skipped (cache hit)")
} else {
authInfoStart := time.Now()
uid, aid, err := s.accountManager.GetPeerAuthInfo(ctx, peerKey.String())
if err != nil {
s.syncSem.Add(-1)
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return mapError(ctx, err)
}
userID = uid
accountID = aid
log.WithContext(ctx).Debugf("fast path: GetPeerAuthInfo took %s", time.Since(authInfoStart))
}
metahashed := metaHash(peerMeta, sRealIP)
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
@@ -271,19 +333,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
s.syncSem.Add(-1)
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
s.syncSem.Add(-1)
return err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
@@ -294,7 +343,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
unlock()
}
}()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("fast path: acquirePeerLockByUID took %s", time.Since(start))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
@@ -305,6 +354,12 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
took, skipReason, err := s.tryFastPathSync(ctx, reqStart, syncStart, accountID, peerKey, peerMeta, realIP, metahash, srv, &unlock)
if took {
return err
}
log.WithContext(ctx).Debugf("Sync (fast path) skipped reason=%s", skipReason)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
@@ -319,6 +374,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
return err
}
s.recordPeerSyncEntry(peerKey.String(), netMap, metahash, peer)
updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
if err != nil {
@@ -340,7 +396,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
return s.handleUpdates(ctx, accountID, peerKey, peer, metahash, updates, srv, syncStart)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
@@ -410,8 +466,9 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
// handleUpdates sends updates to the connected peer until the updates channel is closed.
// It implements a backpressure mechanism that sends the first update immediately,
// then debounces subsequent rapid updates, ensuring only the latest update is sent
// after a quiet period.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
// after a quiet period. peerMetaHash is forwarded to sendUpdate so the peer-sync
// cache can record the serial this peer just received.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerMetaHash uint64, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
// Create a debouncer for this peer connection
@@ -436,7 +493,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerMetaHash, update, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
@@ -450,7 +507,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
}
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
for _, pendingUpdate := range pendingUpdates {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerMetaHash, pendingUpdate, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
@@ -468,7 +525,9 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
// For MessageTypeNetworkMap updates it records the delivered serial in the
// peer-sync cache so a subsequent Sync with the same serial can take the fast path.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerMetaHash uint64, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
@@ -488,6 +547,9 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message")
}
if update.MessageType == network_map.MessageTypeNetworkMap {
s.recordPeerSyncEntryFromUpdate(peerKey.String(), update, peerMetaHash, peer)
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
@@ -772,6 +834,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(ctx, err)
}
s.invalidatePeerSyncEntry(peerKey.String())
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil {

View File

@@ -0,0 +1,562 @@
package grpc
import (
"context"
"net"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// peerGroupFetcher returns the group IDs a peer belongs to. It is a dependency
// of buildFastPathResponse so tests can inject a stub without a real store.
type peerGroupFetcher func(ctx context.Context, accountID, peerID string) ([]string, error)
// peerSyncEntry records what the server last delivered to a peer on Sync so we
// can decide whether the next Sync can skip the full network map computation.
// It also carries the minimum peer/auth metadata needed to run the fast path
// without a DB round-trip on cache hit.
type peerSyncEntry struct {
// Serial is the NetworkMap.Serial the server last included in a full map
// delivered to this peer.
Serial uint64
// MetaHash is the metaHash() value of the peer metadata at the time of that
// delivery, used to detect a meta change on reconnect.
MetaHash uint64
// AccountID is the peer's account ID. Cached so the Sync hot path can skip
// GetPeerAuthInfo on cache hit.
AccountID string
// PeerID is the peer's internal ID, needed for network-map subscription
// and update-channel routing.
PeerID string
// PeerKey mirrors the cache key (peer's wireguard pubkey) so the peer
// snapshot carries everything required by cancelPeerRoutines without a
// second store lookup.
PeerKey string
// Ephemeral is the peer's ephemeral flag, used by EphemeralPeersManager
// on subscribe/unsubscribe.
Ephemeral bool
// HasUser is true if the peer is user-owned (peer.UserID != ""). Used in
// place of GetUserIDByPeerKey's result to drive the loginFilter gate on
// cache hit.
HasUser bool
}
// IsComplete reports whether the entry has every field the pure-cache fast
// path needs. Entries written by older code (before step 2) will carry only
// Serial and MetaHash and must fall back to the slow path so the cache is
// repopulated with the full shape.
func (e peerSyncEntry) IsComplete() bool {
return e.AccountID != "" && e.PeerID != "" && e.PeerKey != ""
}
// PeerSnapshot reconstructs the minimum *nbpeer.Peer needed by
// OnPeerConnectedWithPeer, EphemeralPeersManager, handleUpdates,
// cancelPeerRoutines, and buildFastPathResponse.
func (e peerSyncEntry) PeerSnapshot() *nbpeer.Peer {
return &nbpeer.Peer{
ID: e.PeerID,
Key: e.PeerKey,
AccountID: e.AccountID,
Ephemeral: e.Ephemeral,
}
}
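The legacy-entry guard matters because older deployments wrote entries carrying only `Serial` and `MetaHash`. A self-contained sketch (struct fields and `IsComplete` copied from the diff; the `main` is illustrative only) shows how such entries are forced onto the slow path:

```go
package main

import "fmt"

// peerSyncEntry carries the fields cached per peer, as in the diff above.
type peerSyncEntry struct {
	Serial    uint64
	MetaHash  uint64
	AccountID string
	PeerID    string
	PeerKey   string
	Ephemeral bool
	HasUser   bool
}

// IsComplete reports whether the entry has every field the pure-cache fast
// path needs; legacy entries with only Serial and MetaHash must miss so the
// cache is repopulated with the full shape.
func (e peerSyncEntry) IsComplete() bool {
	return e.AccountID != "" && e.PeerID != "" && e.PeerKey != ""
}

func main() {
	// A legacy entry deserialized from an old cache write: snapshot fields zero.
	legacy := peerSyncEntry{Serial: 42, MetaHash: 7}

	// A full-shape entry written by the current code.
	full := peerSyncEntry{
		Serial: 42, MetaHash: 7,
		AccountID: "acc-1", PeerID: "peer-1", PeerKey: "wg-pubkey",
	}

	fmt.Println(legacy.IsComplete(), full.IsComplete())
}
```

Because JSON unmarshalling leaves absent fields at their zero values, the same `IsComplete` check covers both legacy payloads and partially corrupted writes.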
// lookupPeerAuthFromCache checks whether the peer-sync cache holds a complete
// entry for this peer with a matching metaHash, so the Sync handler can skip
// the pre-fast-path GetPeerAuthInfo store read. Returns hit=false whenever
// the fast path is disabled, the peer is Android, the cache is empty, the
// entry is from an older shape without snapshot fields, or metaHash differs.
func (s *Server) lookupPeerAuthFromCache(peerPubKey string, incomingMetaHash uint64, goOS string) (peerSyncEntry, bool) {
if s.peerSerialCache == nil {
return peerSyncEntry{}, false
}
if !s.fastPathFlag.Enabled() {
return peerSyncEntry{}, false
}
if strings.EqualFold(goOS, "android") {
return peerSyncEntry{}, false
}
entry, hit := s.peerSerialCache.Get(peerPubKey)
if !hit || !entry.IsComplete() {
return peerSyncEntry{}, false
}
if entry.MetaHash != incomingMetaHash {
return peerSyncEntry{}, false
}
return entry, true
}
// recordFastPathSkip bumps the slow-path sync counter with a reason label.
// Used from every early-return site in tryFastPathSync so Grafana can graph
// fast-path misses by cause. No log is emitted here; once tryFastPathSync
// returns, the Sync handler in server.go writes a single
// "Sync (fast path) skipped reason=<reason>" line, so a grep for
// "Sync (fast path)" finds both the success and the skip outcomes in one
// shot.
func (s *Server) recordFastPathSkip(_ context.Context, reason string) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSlowPathSync(reason)
}
}
// fastPathSkipReason returns a short reason tag when the eligibility check
// fails, or "" when the fast path should run. Mirrors shouldSkipNetworkMap's
// logic but attributes each individual failure condition so callers can log
// a histogram of fast-path misses.
func fastPathSkipReason(hit bool, cached peerSyncEntry, currentSerial, incomingMetaHash uint64) string {
if !hit {
return "cache_miss"
}
if cached.Serial == 0 {
return "cached_serial_zero"
}
if cached.Serial != currentSerial {
return "serial_mismatch"
}
if cached.MetaHash != incomingMetaHash {
return "meta_mismatch"
}
return ""
}
// shouldSkipNetworkMap reports whether a Sync request from this peer can be
// answered with a lightweight NetbirdConfig-only response instead of a full
// map computation. All conditions must hold:
// - the peer is not Android (Android's GrpcClient.GetNetworkMap errors on nil map)
// - the cache holds an entry for this peer
// - the cached serial is non-zero (a guard against uninitialised entries)
// - the cached serial matches the current account serial
// - the cached meta hash matches the incoming meta hash
func shouldSkipNetworkMap(goOS string, hit bool, cached peerSyncEntry, currentSerial, incomingMetaHash uint64) bool {
if strings.EqualFold(goOS, "android") {
return false
}
if !hit {
return false
}
if cached.Serial == 0 {
return false
}
if cached.Serial != currentSerial {
return false
}
if cached.MetaHash != incomingMetaHash {
return false
}
return true
}
// extraSettingsFetcher is the dependency used by buildFastPathResponse to
// obtain ExtraSettings for the peer's account. Matches the shape of the
// method on settings.Manager but as a plain function so production callers
// can wrap it with a cache and tests can inject a stub.
type extraSettingsFetcher func(ctx context.Context, accountID string) (*nbtypes.ExtraSettings, error)
// buildFastPathResponse constructs a SyncResponse containing only NetbirdConfig
// with fresh TURN/Relay tokens, mirroring the shape used by
// TimeBasedAuthSecretsManager when pushing token refreshes. The response omits
// NetworkMap, PeerConfig, Checks and RemotePeers; the client keeps its existing
// state and only refreshes its control-plane credentials.
func buildFastPathResponse(
ctx context.Context,
cfg *nbconfig.Config,
secrets SecretsManager,
fetchExtraSettings extraSettingsFetcher,
fetchGroups peerGroupFetcher,
peer *nbpeer.Peer,
) *proto.SyncResponse {
var turnToken *Token
if cfg != nil && cfg.TURNConfig != nil && cfg.TURNConfig.TimeBasedCredentials {
if t, err := secrets.GenerateTurnToken(); err == nil {
turnToken = t
} else {
log.WithContext(ctx).Warnf("fast path: generate TURN token: %v", err)
}
}
var relayToken *Token
if cfg != nil && cfg.Relay != nil && len(cfg.Relay.Addresses) > 0 {
if t, err := secrets.GenerateRelayToken(); err == nil {
relayToken = t
} else {
log.WithContext(ctx).Warnf("fast path: generate relay token: %v", err)
}
}
var extraSettings *nbtypes.ExtraSettings
if fetchExtraSettings != nil {
if es, err := fetchExtraSettings(ctx, peer.AccountID); err != nil {
log.WithContext(ctx).Debugf("fast path: get extra settings: %v", err)
} else {
extraSettings = es
}
}
nbConfig := toNetbirdConfig(cfg, turnToken, relayToken, extraSettings)
var peerGroups []string
if fetchGroups != nil {
if ids, err := fetchGroups(ctx, peer.AccountID, peer.ID); err != nil {
log.WithContext(ctx).Debugf("fast path: get peer group ids: %v", err)
} else {
peerGroups = ids
}
}
extendStart := time.Now()
nbConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
log.WithContext(ctx).Debugf("fast path: ExtendNetBirdConfig took %s", time.Since(extendStart))
return &proto.SyncResponse{NetbirdConfig: nbConfig}
}
// fetchExtraSettings returns a cached ExtraSettings when available, falling
// back to the settings manager on miss. Populates the cache on miss so
// subsequent fast-path Syncs hit it.
func (s *Server) fetchExtraSettings(ctx context.Context, accountID string) (*nbtypes.ExtraSettings, error) {
if es, ok := s.extraSettingsCache.get(accountID); ok {
log.WithContext(ctx).Debugf("fast path: GetExtraSettings skipped (cache hit)")
return es, nil
}
start := time.Now()
es, err := s.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return nil, err
}
log.WithContext(ctx).Debugf("fast path: GetExtraSettings took %s", time.Since(start))
s.extraSettingsCache.set(accountID, es)
return es, nil
}
// tryFastPathSync decides whether the current Sync can be answered with a
// lightweight NetbirdConfig-only response. When the fast path runs, it takes
// over the whole Sync handler (MarkPeerConnected, send, OnPeerConnected,
// SetupRefresh, handleUpdates) and the returned took value is true.
//
// When took is true the caller must return the accompanying err. When took is
// false the caller falls through to the existing slow path and should log
// "Sync (fast path) skipped reason=<skipReason>" so a single grep on
// "Sync (fast path)" finds both the fast-path successes and misses.
func (s *Server) tryFastPathSync(
ctx context.Context,
reqStart, syncStart time.Time,
accountID string,
peerKey wgtypes.Key,
peerMeta nbpeer.PeerSystemMeta,
realIP net.IP,
peerMetaHash uint64,
srv proto.ManagementService_SyncServer,
unlock *func(),
) (took bool, skipReason string, err error) {
if s.peerSerialCache == nil {
s.recordFastPathSkip(ctx, "cache_disabled")
return false, "cache_disabled", nil
}
if !s.fastPathFlag.Enabled() {
s.recordFastPathSkip(ctx, "flag_off")
return false, "flag_off", nil
}
if strings.EqualFold(peerMeta.GoOS, "android") {
s.recordFastPathSkip(ctx, "android")
return false, "android", nil
}
networkStart := time.Now()
currentSerial, err := s.accountManager.GetStore().GetAccountNetworkSerial(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Debugf("fast path: account network serial lookup error: %v", err)
s.recordFastPathSkip(ctx, "account_network_error")
return false, "account_network_error", nil
}
log.WithContext(ctx).Debugf("fast path: initial GetAccountNetworkSerial took %s", time.Since(networkStart))
eligibilityStart := time.Now()
cached, hit := s.peerSerialCache.Get(peerKey.String())
if reason := fastPathSkipReason(hit, cached, currentSerial, peerMetaHash); reason != "" {
log.WithContext(ctx).Debugf("fast path: eligibility check took %s", time.Since(eligibilityStart))
s.recordFastPathSkip(ctx, reason)
return false, reason, nil
}
log.WithContext(ctx).Debugf("fast path: eligibility check (hit) took %s", time.Since(eligibilityStart))
var cachedPeer *nbpeer.Peer
if cached.IsComplete() {
cachedPeer = cached.PeerSnapshot()
}
peer, updates, committed := s.commitFastPath(ctx, accountID, peerKey, realIP, syncStart, cachedPeer)
if !committed {
s.recordFastPathSkip(ctx, "commit_failed")
return false, "commit_failed", nil
}
// Upgrade the cache only when we had to fetch the peer from the store
// this Sync. In the steady state the cached snapshot lacks UserID (not
// part of PeerSnapshot), so rewriting from it would flip HasUser to
// false and corrupt the entry. A cache-served peer also means the
// entry is already in the full shape, so there's nothing to upgrade.
upgradeCache := cachedPeer == nil
return true, "", s.runFastPathSync(ctx, reqStart, syncStart, accountID, peerKey, peer, updates, cached.Serial, peerMetaHash, upgradeCache, srv, unlock)
}
// commitFastPath subscribes the peer to network-map updates and marks it
// connected. When cachedPeer is non-nil (cache hit with a complete entry),
// the expensive GetPeerByPeerPubKey store call is skipped and the cached
// snapshot is used instead.
//
// It relies on the same eventual-consistency guarantee as the slow path: a
// concurrent writer's broadcast may race the subscription, but any subsequent
// serial change reaches the subscribed peer via its update channel, and a
// reconnect with a stale cached serial falls through to the slow path on the
// next Sync. Returns committed=false on any failure that should not block
// the slow path from running.
func (s *Server) commitFastPath(
ctx context.Context,
accountID string,
peerKey wgtypes.Key,
realIP net.IP,
syncStart time.Time,
cachedPeer *nbpeer.Peer,
) (*nbpeer.Peer, chan *network_map.UpdateMessage, bool) {
commitStart := time.Now()
defer func() {
log.WithContext(ctx).Debugf("fast path: commitFastPath took %s", time.Since(commitStart))
}()
var peer *nbpeer.Peer
if cachedPeer != nil {
peer = cachedPeer
log.WithContext(ctx).Debugf("fast path: GetPeerByPeerPubKey skipped (cache hit)")
} else {
getPeerStart := time.Now()
p, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
log.WithContext(ctx).Debugf("fast path: lookup peer %s: %v", peerKey.String(), err)
return nil, nil, false
}
log.WithContext(ctx).Debugf("fast path: GetPeerByPeerPubKey took %s", time.Since(getPeerStart))
peer = p
}
onConnectedStart := time.Now()
updates, err := s.networkMapController.OnPeerConnectedWithPeer(ctx, accountID, peer)
if err != nil {
log.WithContext(ctx).Debugf("fast path: notify peer connected for %s: %v", peerKey.String(), err)
return nil, nil, false
}
log.WithContext(ctx).Debugf("fast path: OnPeerConnectedWithPeer took %s", time.Since(onConnectedStart))
s.markPeerConnectedAsync(peerKey.String(), realIP, accountID, syncStart)
return peer, updates, true
}
// markPeerConnectedAsync fires MarkPeerConnected in a detached goroutine so
// the Sync hot path does not wait on a DB write that can spike into the
// multi-second range under contention. LastSeen is eventually consistent and
// lags by at most one write; the peer's next Sync or the per-peer expiration
// routines correct any drift. Concurrent fast-path Syncs for the same peer
// coalesce to a single background write via the inflight map.
func (s *Server) markPeerConnectedAsync(peerKey string, realIP net.IP, accountID string, syncStart time.Time) {
if _, loaded := s.inflightMarkPeerConnected.LoadOrStore(peerKey, struct{}{}); loaded {
log.Debugf("fast path: async MarkPeerConnected for %s coalesced (already in flight)", peerKey)
return
}
go func() {
defer s.inflightMarkPeerConnected.Delete(peerKey)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
start := time.Now()
if err := s.accountManager.MarkPeerConnected(ctx, peerKey, true, realIP, accountID, syncStart); err != nil {
log.Warnf("fast path: async MarkPeerConnected for %s: %v", peerKey, err)
return
}
log.Debugf("fast path: async MarkPeerConnected for %s took %s", peerKey, time.Since(start))
}()
}
// runFastPathSync executes the fast path: send the lean response, kick off
// token refresh, release the per-peer lock, then block on handleUpdates until
// the stream is closed. Peer lookup and subscription have already been
// performed by commitFastPath, so the race between the eligibility check and
// the subscription is already closed.
func (s *Server) runFastPathSync(
ctx context.Context,
reqStart, syncStart time.Time,
accountID string,
peerKey wgtypes.Key,
peer *nbpeer.Peer,
updates chan *network_map.UpdateMessage,
serial uint64,
peerMetaHash uint64,
upgradeCache bool,
srv proto.ManagementService_SyncServer,
unlock *func(),
) error {
sendStart := time.Now()
if err := s.sendFastPathResponse(ctx, peerKey, peer, srv); err != nil {
log.WithContext(ctx).Debugf("fast path: send response for peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
return err
}
log.WithContext(ctx).Debugf("fast path: sendFastPathResponse took %s", time.Since(sendStart))
// Upgrade a legacy-shape cache entry (Serial + MetaHash only, pre step 2)
// to the full shape so the next Sync's lookupPeerAuthFromCache +
// commitFastPath can actually short-circuit the pre-fast-path
// GetPeerAuthInfo and GetPeerByPeerPubKey. Only runs when the peer was
// freshly fetched from the store this Sync — rewriting from a cached
// snapshot would lose HasUser because PeerSnapshot doesn't carry UserID.
if upgradeCache {
s.writePeerSyncEntry(peerKey.String(), serial, peerMetaHash, peer)
}
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
if unlock != nil && *unlock != nil {
(*unlock)()
*unlock = nil
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
s.appMetrics.GRPCMetrics().CountFastPathSync()
}
log.WithContext(ctx).Debugf("Sync (fast path) took %s", time.Since(reqStart))
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, peerMetaHash, updates, srv, syncStart)
}
// sendFastPathResponse builds a NetbirdConfig-only SyncResponse, encrypts it
// with the peer's WireGuard key and pushes it over the stream.
func (s *Server) sendFastPathResponse(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, srv proto.ManagementService_SyncServer) error {
resp := buildFastPathResponse(ctx, s.config, s.secretsManager, s.fetchExtraSettings, s.fetchPeerGroups, peer)
key, err := s.secretsManager.GetWGKey()
if err != nil {
return status.Errorf(codes.Internal, "failed getting server key")
}
body, err := encryption.EncryptMessage(peerKey, key, resp)
if err != nil {
return status.Errorf(codes.Internal, "error encrypting fast-path sync response")
}
if err := srv.Send(&proto.EncryptedMessage{
WgPubKey: key.PublicKey().String(),
Body: body,
}); err != nil {
log.WithContext(ctx).Errorf("failed sending fast-path sync response: %v", err)
return status.Errorf(codes.Internal, "error handling request")
}
return nil
}
// fetchPeerGroups returns a cached list of group IDs for the peer when
// available, falling back to the account manager's store on miss. Populates
// the cache on miss so subsequent fast-path Syncs hit it.
func (s *Server) fetchPeerGroups(ctx context.Context, accountID, peerID string) ([]string, error) {
if ids, ok := s.peerGroupsCache.get(peerID); ok {
log.WithContext(ctx).Debugf("fast path: GetPeerGroupIDs skipped (cache hit)")
return ids, nil
}
start := time.Now()
ids, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
log.WithContext(ctx).Debugf("fast path: GetPeerGroupIDs took %s", time.Since(start))
s.peerGroupsCache.set(peerID, ids)
return ids, nil
}
// recordPeerSyncEntry writes the serial just delivered to this peer so a
// subsequent reconnect can take the fast path. Called after the slow path's
// sendInitialSync has pushed a full map. A nil cache disables the fast path.
// peer is required so the cached entry carries the snapshot fields the
// pure-cache fast path needs (AccountID, PeerID, Key, Ephemeral, HasUser).
func (s *Server) recordPeerSyncEntry(peerKey string, netMap *nbtypes.NetworkMap, peerMetaHash uint64, peer *nbpeer.Peer) {
if netMap == nil || netMap.Network == nil {
return
}
s.writePeerSyncEntry(peerKey, netMap.Network.CurrentSerial(), peerMetaHash, peer)
}
// recordPeerSyncEntryFromUpdate is the sendUpdate equivalent of
// recordPeerSyncEntry: it extracts the serial from a streamed NetworkMap update
// so the cache stays in sync with what the peer most recently received.
func (s *Server) recordPeerSyncEntryFromUpdate(peerKey string, update *network_map.UpdateMessage, peerMetaHash uint64, peer *nbpeer.Peer) {
if update == nil || update.Update == nil || update.Update.NetworkMap == nil {
return
}
s.writePeerSyncEntry(peerKey, update.Update.NetworkMap.GetSerial(), peerMetaHash, peer)
}
// writePeerSyncEntry is the common cache write used by every path that
// delivers state to a peer: the slow-path sendInitialSync, streamed
// NetworkMap updates, and the fast path itself. Writing from the fast path
// upgrades legacy-shape entries (Serial + MetaHash only, pre step 2) to the
// full shape on the next successful Sync so subsequent cache hits can
// actually short-circuit GetPeerAuthInfo and GetPeerByPeerPubKey.
func (s *Server) writePeerSyncEntry(peerKey string, serial, peerMetaHash uint64, peer *nbpeer.Peer) {
if s.peerSerialCache == nil {
return
}
if !s.fastPathFlag.Enabled() {
return
}
if serial == 0 {
return
}
s.peerSerialCache.Set(peerKey, newPeerSyncEntry(serial, peerMetaHash, peer))
}
// newPeerSyncEntry builds a cache entry with every field the pure-cache
// fast path needs. peer may be nil (legacy call sites), in which case the
// entry is written without the snapshot fields and will fail IsComplete().
func newPeerSyncEntry(serial, metaHash uint64, peer *nbpeer.Peer) peerSyncEntry {
entry := peerSyncEntry{
Serial: serial,
MetaHash: metaHash,
}
if peer != nil {
entry.AccountID = peer.AccountID
entry.PeerID = peer.ID
entry.PeerKey = peer.Key
entry.Ephemeral = peer.Ephemeral
entry.HasUser = peer.UserID != ""
}
return entry
}
// invalidatePeerSyncEntry is called after a successful Login so the next Sync
// is guaranteed to deliver a full map, picking up whatever changed in the
// login (SSH key rotation, approval state, user binding, etc.).
func (s *Server) invalidatePeerSyncEntry(peerKey string) {
if s.peerSerialCache == nil {
return
}
s.peerSerialCache.Delete(peerKey)
}

View File

@@ -0,0 +1,163 @@
package grpc
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
func fastPathTestPeer() *nbpeer.Peer {
return &nbpeer.Peer{
ID: "peer-id",
AccountID: "account-id",
Key: "pubkey",
}
}
func fastPathTestSecrets(t *testing.T, turn *config.TURNConfig, relay *config.Relay) *TimeBasedAuthSecretsManager {
t.Helper()
peersManager := update_channel.NewPeersUpdateManager(nil)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMock := settings.NewMockManager(ctrl)
secrets, err := NewTimeBasedAuthSecretsManager(peersManager, turn, relay, settingsMock, groups.NewManagerMock())
require.NoError(t, err, "secrets manager initialisation must succeed")
return secrets
}
func noGroupsFetcher(context.Context, string, string) ([]string, error) {
return nil, nil
}
func TestBuildFastPathResponse_TimeBasedTURNAndRelay_FreshTokens(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
turnCfg := &config.TURNConfig{
CredentialsTTL: ttl,
Secret: "turn-secret",
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}
relayCfg := &config.Relay{
Addresses: []string{"rel.example:443"},
CredentialsTTL: ttl,
Secret: "relay-secret",
}
cfg := &config.Config{
TURNConfig: turnCfg,
Relay: relayCfg,
Signal: &config.Host{URI: "signal.example:443", Proto: config.HTTPS},
Stuns: []*config.Host{{URI: "stun.example:3478", Proto: config.UDP}},
}
secrets := fastPathTestSecrets(t, turnCfg, relayCfg)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().GetExtraSettings(gomock.Any(), "account-id").Return(&types.ExtraSettings{}, nil).AnyTimes()
resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock.GetExtraSettings, noGroupsFetcher, fastPathTestPeer())
require.NotNil(t, resp, "response must not be nil")
assert.Nil(t, resp.NetworkMap, "fast path must omit NetworkMap")
assert.Nil(t, resp.PeerConfig, "fast path must omit PeerConfig")
assert.Empty(t, resp.Checks, "fast path must omit posture checks")
assert.Empty(t, resp.RemotePeers, "fast path must omit remote peers")
require.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must be present on fast path")
require.Len(t, resp.NetbirdConfig.Turns, 1, "time-based TURN credentials must be present")
assert.NotEmpty(t, resp.NetbirdConfig.Turns[0].User, "TURN user must be populated")
assert.NotEmpty(t, resp.NetbirdConfig.Turns[0].Password, "TURN password must be populated")
require.NotNil(t, resp.NetbirdConfig.Relay, "Relay config must be present when configured")
assert.NotEmpty(t, resp.NetbirdConfig.Relay.TokenPayload, "relay token payload must be populated")
assert.NotEmpty(t, resp.NetbirdConfig.Relay.TokenSignature, "relay token signature must be populated")
assert.Equal(t, []string{"rel.example:443"}, resp.NetbirdConfig.Relay.Urls, "relay URLs passthrough")
require.NotNil(t, resp.NetbirdConfig.Signal, "Signal config must be present when configured")
assert.Equal(t, "signal.example:443", resp.NetbirdConfig.Signal.Uri, "signal URI passthrough")
require.Len(t, resp.NetbirdConfig.Stuns, 1, "STUNs must be passed through")
assert.Equal(t, "stun.example:3478", resp.NetbirdConfig.Stuns[0].Uri, "STUN URI passthrough")
}
func TestBuildFastPathResponse_StaticTURNCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
staticHost := &config.Host{
URI: "turn:static.example:3478",
Proto: config.UDP,
Username: "preset-user",
Password: "preset-pass",
}
turnCfg := &config.TURNConfig{
CredentialsTTL: ttl,
Secret: "turn-secret",
Turns: []*config.Host{staticHost},
TimeBasedCredentials: false,
}
cfg := &config.Config{TURNConfig: turnCfg}
// Use a relay-free secrets manager; static TURN path does not consult it.
secrets := fastPathTestSecrets(t, turnCfg, nil)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock.GetExtraSettings, noGroupsFetcher, fastPathTestPeer())
require.NotNil(t, resp.NetbirdConfig)
require.Len(t, resp.NetbirdConfig.Turns, 1, "static TURN must appear in response")
assert.Equal(t, "preset-user", resp.NetbirdConfig.Turns[0].User, "static user passthrough")
assert.Equal(t, "preset-pass", resp.NetbirdConfig.Turns[0].Password, "static password passthrough")
assert.Nil(t, resp.NetbirdConfig.Relay, "no Relay when Relay config is nil")
}
func TestBuildFastPathResponse_NoRelayConfigured_NoRelaySection(t *testing.T) {
cfg := &config.Config{}
secrets := fastPathTestSecrets(t, nil, nil)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock.GetExtraSettings, noGroupsFetcher, fastPathTestPeer())
require.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must be non-nil even without relay/turn")
assert.Nil(t, resp.NetbirdConfig.Relay, "Relay must be absent when not configured")
assert.Empty(t, resp.NetbirdConfig.Turns, "Turns must be empty when not configured")
}
func TestBuildFastPathResponse_ExtraSettingsErrorStillReturnsResponse(t *testing.T) {
cfg := &config.Config{}
secrets := fastPathTestSecrets(t, nil, nil)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(nil, assertAnError).AnyTimes()
resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock.GetExtraSettings, noGroupsFetcher, fastPathTestPeer())
require.NotNil(t, resp, "extra settings failure must degrade gracefully into an empty fast-path response")
assert.Nil(t, resp.NetworkMap, "NetworkMap still omitted on degraded path")
}
// assertAnError is a sentinel used by fast-path tests that need to simulate a
// dependency failure without caring about the error value.
var assertAnError = errForTests("simulated")
type errForTests string
func (e errForTests) Error() string { return string(e) }

View File

@@ -39,11 +39,8 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
@@ -327,7 +324,7 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context,
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
return nil
}
@@ -335,7 +332,7 @@ func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
return nil
}
@@ -351,6 +348,18 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *testValidateSessionProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store
}

View File

@@ -30,6 +30,7 @@ import (
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/fastpathcache"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -111,6 +112,11 @@ type DefaultAccountManager struct {
permissionsManager permissions.Manager
disableDefaultPolicy bool
// sharedCacheStore is retained so mutation paths can invalidate the
// Sync fast-path caches (ExtraSettings, peer-groups) without a circular
// dependency on the gRPC server package that owns the read-side wrappers.
sharedCacheStore cacheStore.StoreInterface
}
var _ account.Manager = (*DefaultAccountManager)(nil)
@@ -181,7 +187,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
return modified, newUserAutoGroups, newGroupsToCreate, nil
}
// BuildManager creates a new DefaultAccountManager with a provided Store
// BuildManager creates a new DefaultAccountManager with all dependencies.
func BuildManager(
ctx context.Context,
config *nbconfig.Config,
@@ -199,6 +205,7 @@ func BuildManager(
settingsManager settings.Manager,
permissionsManager permissions.Manager,
disableDefaultPolicy bool,
sharedCacheStore cacheStore.StoreInterface,
) (*DefaultAccountManager, error) {
start := time.Now()
defer func() {
@@ -247,16 +254,13 @@ func BuildManager(
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
}
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
if err != nil {
return nil, fmt.Errorf("getting cache store: %s", err)
}
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
am.externalCacheManager = nbcache.NewUserDataCache(sharedCacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, sharedCacheStore)
am.sharedCacheStore = sharedCacheStore
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
go func() {
err := am.warmupIDPCache(ctx, cacheStore)
err := am.warmupIDPCache(ctx, sharedCacheStore)
if err != nil {
log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err)
// todo retry?
@@ -371,6 +375,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if err != nil {
return nil, err
}
if extraSettingsChanged {
fastpathcache.InvalidateExtraSettings(ctx, am.sharedCacheStore, accountID)
}
am.handleRoutingPeerDNSResolutionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
@@ -2290,3 +2297,9 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
func (am *DefaultAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
return am.Store.GetUserIDByPeerKey(ctx, store.LockingStrengthNone, peerKey)
}
// GetPeerAuthInfo returns the userID and accountID for a peer in a single
// store call. Used by the Sync hot path to collapse two lookups into one.
func (am *DefaultAccountManager) GetPeerAuthInfo(ctx context.Context, peerKey string) (string, string, error) {
return am.Store.GetPeerAuthInfoByPubKey(ctx, store.LockingStrengthNone, peerKey)
}

View File

@@ -134,6 +134,9 @@ type Manager interface {
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error)
// GetPeerAuthInfo returns the userID and accountID for a peer in a single
// store call. Used by the Sync hot path to collapse two lookups into one.
GetPeerAuthInfo(ctx context.Context, peerKey string) (userID, accountID string, err error)
GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error)
GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error)
CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)

View File

@@ -900,6 +900,22 @@ func (mr *MockManagerMockRecorder) GetPeer(ctx, accountID, peerID, userID interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeer", reflect.TypeOf((*MockManager)(nil).GetPeer), ctx, accountID, peerID, userID)
}
// GetPeerAuthInfo mocks base method.
func (m *MockManager) GetPeerAuthInfo(ctx context.Context, peerKey string) (string, string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerAuthInfo", ctx, peerKey)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(string)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetPeerAuthInfo indicates an expected call of GetPeerAuthInfo.
func (mr *MockManagerMockRecorder) GetPeerAuthInfo(ctx, peerKey interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAuthInfo", reflect.TypeOf((*MockManager)(nil).GetPeerAuthInfo), ctx, peerKey)
}
// GetPeerGroups mocks base method.
func (m *MockManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
m.ctrl.T.Helper()

Some files were not shown because too many files have changed in this diff.