Compare commits

...

41 Commits

Author SHA1 Message Date
Dmitri
b21f7f7d6a updated event aggregation test
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-12 15:14:17 +02:00
Dmitri
98ce097ecb update test to validate event aggregation over tcp, udp, icmp, and icmpv6
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-11 15:34:03 +02:00
Dmitri
598558c77e Merge remote-tracking branch 'origin/main' into dmitri-event-aggregation 2026-06-11 13:32:28 +02:00
Dmitri
d9d585e1d4 pacifying linter
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-11 12:19:58 +02:00
Dmitri
a593e32a1d removed inadvertenly added google proto files
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-11 12:07:29 +02:00
Dmitri
12a8943b99 regenerated proto files
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-11 12:03:46 +02:00
Dmitri
42e0007f4a fixes based on sonarcube checks
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-11 10:18:08 +02:00
Maycon Santos
d7703767d5 [client, proxy] cancel context before stopping engine on embedded client (#6397)
- Engine.Start takes syncMsgMux with a deferred unlock (engine.go:445) and parks in receiveSignalEvents → WaitStreamConnected (engine.go:1762), which only wakes on
  signal-stream connect or client-context cancellation.
  - When signal never connects, the 30s startup timeout fires and embed.Client.Start's rollback (embed.go:281) called client.Stop() → Engine.Stop, which blocks acquiring
  syncMsgMux (engine.go:318). The cancel() that would unpark Start was deferred until Start returned — permanent cycle. RemovePeer calls (g43/g385) then queue behind the
  lifecycle mutex.
  - Notably, embed.Client.Stop and the daemon's cleanupConnection both cancel before stopping — the startup rollback was the only path that didn't.
  - Engine.Start takes syncMsgMux with a deferred unlock (engine.go:445) and parks in receiveSignalEvents → WaitStreamConnected (engine.go:1762), which only wakes on
  signal-stream connect or client-context cancellation.
  - When signal never connects, the 30s startup timeout fires and embed.Client.Start's rollback (embed.go:281) called client.Stop() → Engine.Stop, which blocks acquiring
  syncMsgMux (engine.go:318). The cancel() that would unpark Start was deferred until Start returned — permanent cycle. RemovePeer calls (g43/g385) then queue behind the
  lifecycle mutex.
  - Notably, embed.Client.Stop and the daemon's cleanupConnection both cancel before stopping — the startup rollback was the only path that didn't.
2026-06-10 21:26:54 +02:00
Maycon Santos
7feda907ca [management] fix L4 service update when no custom port (#6396)
This fixes an issue where L4 service update is not possible when proxy clusters don't support custom ports
2026-06-10 18:55:24 +02:00
Maycon Santos
62da482133 [management] Add version gate to stop sending deprecated RemotePeers field (#6371)
* [management] Add version gate to stop sending deprecated RemotePeers field

don't send top-level remote peers on peers in the  v0.29.3 or newer

* precompute deprecated remote peers version constraint

* [management] update tests to validate network map-based remote peers

* [management] move deprecatedRemotePeersVersion constant closer to its usage

* fix misplaced precomputed constraint definition

* ensure top-level RemotePeers is empty for v0.29.3+ clients
2026-06-10 16:59:09 +02:00
Dmitri
8f99362a25 added tracking of the number of start-, drop, and end-events in an aggregation window
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-10 16:06:29 +02:00
Philip Laine
079bce3c2f Add commands to discover and write Kubernetes configuration (#6260) 2026-06-10 15:00:10 +02:00
Maycon Santos
1a09aa6715 [misc] Update Go toolchain version in go.mod (#6377) 2026-06-10 14:50:57 +02:00
Dmitri
101ae3ca77 added manager integration test
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-10 14:48:58 +02:00
Maycon Santos
61abf5b9ea [proxy] Use UUID for proxy ID generation (#6391)
Use UUID for proxy ID instead of the second to avoid race conditions when running multiple nodes at the same time.
2026-06-10 13:35:26 +02:00
Boris Dolgov
e229050ba3 [proxy] Notify certificate ready for domains covered by the static certificate (#6389) 2026-06-10 12:05:34 +02:00
Zoltan Papp
e919b2d55d [client] Preserve posture checks on config-only sync updates (#6373)
* [client] Preserve posture checks on config-only sync updates

When management sends a MessageTypeControlConfig update (e.g. relay token
rotation), the SyncResponse carries no NetworkMap and no Checks. Moving the
updateChecksIfNew call after the nm == nil guard ensures posture checks are
only updated when a full network map is present, preventing relay token
rotation from silently clearing the previously applied posture check state.

* [client] Clarify posture check update logic with explicit comment

* [client] Extract NetBird config and sync persistence into helpers

Move the NetbirdConfig handling block out of handleSync into
updateNetbirdConfig and the sync response persistence into
persistSyncResponse, mirroring updateChecksIfNew. This flattens
handleSync and makes the individual update steps unit-testable.
2026-06-10 11:43:24 +02:00
Dmitri
b654a75a43 added tcp-aggregation test
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-10 10:33:37 +02:00
Dmitri
243e93477f initial support for aggregation of events
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-09 15:54:39 +02:00
Pascal Fischer
a40028092d [management] log user agent and return request id (#6380) 2026-06-09 15:24:26 +02:00
Pascal Fischer
13200265d8 [proxy] Add no-blocking mapping updates (#6369) 2026-06-09 13:57:17 +02:00
Viktor Liu
ed7a9363aa [management] Emit IPv6 default permit firewall rule for exit node routes (#6368) 2026-06-09 13:26:43 +02:00
Viktor Liu
d56859dc5d [client] Filter DNS fallback upstreams matching our server IP to prevent loops (#6183) 2026-06-09 12:26:03 +02:00
Dmitri
60bcf7dfc3 added an implementation of aggregating memory store
Signed-off-by: Dmitri <dmitri.external@netbird.io>
2026-06-09 11:38:02 +02:00
Viktor Liu
367d37050b [relay, client] Fall back to WebSocket relay transport on oversized QUIC datagrams (#6339) 2026-06-09 10:25:46 +02:00
Viktor Liu
106527182f [client] Snapshot iptables rule maps before persisting state (#6345) 2026-06-09 10:24:51 +02:00
Viktor Liu
8e1d5b78c2 [client] Preserve user deselect-all across management route sync (#6363) 2026-06-09 10:24:17 +02:00
PizzaLovingNerd
d3b63c6be9 [infrastructure] Better support for atomic distros in install.sh, docker fixes in getting-started.sh (#6139)
* Made the docker check first for getting-started.sh, better atomic support for install.sh

* Check for docker socket perms

* Added fallback for systems without rpm-ostree or bootc.

* macOS fix for docker socket check

* Change error message for docker group.

No longer using a blanket recommendation for the docker group.
2026-06-08 21:38:46 +02:00
Maycon Santos
60d2fa08b0 [client] Mask sensitive data in debug bundle creation (#6364)
* [client] Mask sensitive data in debug bundle creation

* Avoid nil reference in turn and use masked constant
2026-06-08 13:17:04 +02:00
Maycon Santos
1e7b16db0a [management] resolve private services on custom domains in synthesized DNS zones (#6348)
private services on a custom domain didn't resolve on clients — the synthesized DNS zone was anchored to the cluster, and the account's custom domains weren't even
  loaded.

- account.go — SynthesizePrivateServiceZones now keys zones by a resolved apex (privateServiceDomainZone): cluster suffix → registered account.Domains (filtered by matching
  TargetCluster, longest wins) → skip if none. One zone per apex; custom-domain services group under their registered domain.
- sql_store.go — GetAccount now loads account.Domains on both loaders (gorm Preload("Domains") + pgx goroutine via ListCustomDomains; errChan buffer bumped 12→16). This was
  the reason the deploy didn't work — the relation was empty in prod.
- Tests — custom-domain zone synthesis cases (apex resolution, free+custom separation, sibling collapse, cluster mismatch, mixed cluster/custom/public) + GetAccount
  domain-preload tests on sqlite and Postgres.
2026-06-06 12:56:01 +02:00
Maycon Santos
b377d99933 [management] Copy private field on shallowCloneMapping (#6347)
* [management] Copy private field on shallowCloneMapping

added test to ensure clone handles new fields

* Remove unnecessary debug logs from proxy service

* Increase Wasm binary size limit to 60MB in build validation
2026-06-05 22:45:49 +02:00
Theodor Midtlien
512899d82d [client] Prevent corruption from competing log rotation and improve debug bundle (#6214)
* Adds heuristic to detect an edge case on Linux where a system has configured logrotate as a separate service to rotate log files which would mangle our client log files. If we detect logrotate being configured for netbird, we disable our rotation.

* Adds new env var to disable log rotation: NB_LOG_DISABLE_ROTATION

* Adds compressed and plain logrotate files to debug bundle.

* Replaces lumberjack with timberjack (maintained fork with bug fixes and extra features).

* Clarifies which daemon version is running in the bundle stats.

* Change logging for client service status to console
2026-06-04 17:36:45 +02:00
Theodor Midtlien
5993ec6e43 [client] Allow wireguard port to be zero in UI and show port in status command (#6158)
* Allow wireguard port to be set to 0 in UI

* Add wireguard port to cmd status

* Correct protoc version
2026-06-04 15:04:11 +02:00
Maycon Santos
eac6d501c3 [infrastructure] allow docker image overrides for getting started (#6335)
* [infrastructure] allow docker image overrides for getting started

Make dashboard and server image configurations overrideable via environment variables

* [infrastructure] update Traefik gRPC rule to include ProxyService PathPrefix

* make Traefik and CrowdSec images configurable via environment variables
2026-06-04 11:24:47 +02:00
Maycon Santos
deeae30612 [misc] Add Codecov integration and coverage reporting across workflows (#6333) 2026-06-03 19:08:45 +02:00
Bethuel Mmbaga
f3cdf163e1 [management] Export ResolveDomain (#6334) 2026-06-03 19:53:57 +03:00
Zoltan Papp
3e61ccb162 [client] Persist sync response via pluggable store (disk on iOS) (#6331)
* Persist sync response via pluggable store (disk on iOS)

The latest Management sync response (which carries the network map) was
kept in memory for debug bundle generation. On memory-constrained
platforms like iOS the network map can be large enough to matter.

Introduce a syncstore package with a Store interface and two backends:
a memory backend (the previous behavior) and a disk backend that
serializes the response to a file in the state directory. The backend
is selected per-platform at build time: disk on iOS, memory elsewhere.

The disk store clears any leftover file on construction so a fresh
store never reads stale data from an earlier run (e.g. another
profile's network map).

In the engine, drop the separate persistSyncResponse bool: the store is
only instantiated while persistence is enabled, and its presence is
what marks persistence as active. The store is also cleared on engine
close so the file does not linger on disk.

* syncstore: silence nilnil linter on "nothing stored" returns

Get returns (nil, nil) to signal that nothing is stored, which is part
of the Store contract and preserves the original behaviour. Annotate
both backends with //nolint:nilnil so golangci-lint does not flag it.

* syncstore: hold syncRespMux for the whole store Set/Get

Both handleSync and GetLatestSyncResponse snapshotted e.syncStore under
the read lock and then released it before calling Set/Get. That allowed
SetSyncResponsePersistence(false) or engine close to clear the store
mid-call. In particular a concurrent Clear()+nil followed by a late
Set could re-create the file that was just removed, defeating the
leak/lingering protection.

Hold syncRespMux for the duration of the store operation in both spots
so the store cannot be cleared while a Set/Get is in flight.

* syncstore: avoid StateDir "." when state path is empty

On mobile the state path may be empty (the engine tolerates a missing
state file). filepath.Dir("") returns ".", which would make a
disk-backed syncstore write into the working directory instead of
letting NewDiskStore fall back to os.TempDir().

Only set engineConfig.StateDir when path is non-empty.
2026-06-03 14:18:50 +02:00
Viktor Liu
a48c20d8d8 [client] Gate DNS forwarder on BlockInbound (#6257) 2026-06-03 11:33:29 +02:00
Riccardo Manfrin
2b57a7d43b [client, management, misc] expose VCS revision in dev build version output (#6263)
* Refactor to use a common checker for development version

* Adds commit sha to development version for cobra command only

Leave dashboard unaffected

* Adjust for "v0.31.1-dev" test case

which must be considered pre-release

* Drop synthetic "dev"/"0.50.0-dev" firewall feature-gate fixtures

These test cases encoded the loose strings.Contains(v, "dev")
semantics inherited from peerSupportedFirewallFeatures, but
NetbirdVersion() never produces those values — only the literal
"development" (and now "development-<sha>[-dirty]") ever flows
through the wire. The agent owns the semantics of an ephemeral
development build, so the tests should exercise the strings we
actually emit.

Replaced with development, development-<sha> and
development-<sha>-dirty cases that match the HasPrefix("development")
predicate introduced upstream.

* Remove unexistent tests on wire format

The sha / dirty flag are added only when the CLI asks the version.
Account versions is unaffacted and can only strictly match "development"

* Adds tests for IsDevelopmentVersion
2026-06-03 08:56:50 +02:00
Maycon Santos
fa1e241aea [management, client, proxy] Follow-up fixes for private reverse-proxy services (#6268)
* fix(proxy): gate tunnel-peer fast-path on inbound listener marker

forwardWithTunnelPeer previously accepted any RFC1918 / ULA / CGNAT
source IP, so a public client whose address happened to fall in those
ranges could bypass the configured operator auth scheme by colliding
with a known tunnel IP. The fast-path is now gated on
TunnelLookupFromContext(r.Context()) being present — that context value
is attached only by the per-account inbound (overlay) listener, so the
host-facing listener never enters this branch.

Tests updated to reflect the new requirement: requests that don't
carry the inbound marker now fall through to the regular auth flow.

* fix(proxy): harden inbound listener resource + startup-ctx handling

Three correctness fixes on the per-account inbound path, with tests:

- Close the logrus ErrorLog PipeWriter on tearDown. WriterLevel hands
  back an *io.PipeWriter backed by a pipe + scanner goroutine that the
  caller owns; the two writers per account (https + plain) were never
  closed, leaking the pipe and goroutine on every teardown.
- Run the post-Start hooks on context.Background(). runClientStartup
  is launched in a goroutine from AddPeer and was inheriting the
  caller's request-scoped ctx, so a cancelled request could abort the
  inbound bring-up or fail the management status notification. The
  tail is split into notifyClientReady so the contract is testable.

Tests cover the PipeWriter close behaviour and assert the readyHandler
+ NotifyStatus calls receive a non-cancelled background context.

* feat(proxy): short-circuit peer-own-target loops with 421

When a peer that hosts the target of a private service dials its own
service URL the request was being looped through the proxy and back
over WireGuard to the same peer — twice the WG round-trip for no
benefit, with no signal to the caller that something was wrong.

Add isSelfTargetLoop to ReverseProxy.ServeHTTP: when the request
arrived on the per-account overlay listener (IsOverlayOrigin) and the
source tunnel IP matches the target host, refuse the request with 421
Misdirected Request and a body pointing the operator at the backend
directly.

The gate is scoped to overlay origin so requests on the public
listener that happen to share a source IP with the target host are
forwarded normally.

* fix(management): private-service validation + tunnel-IP lookup semantics

- Require an explicit port for L4 cluster targets. validateL4Target
  exempted TargetTypeCluster from the port check, but buildPathMappings
  serializes every L4 target via net.JoinHostPort(host, port) — port=0
  shipped a ":0" upstream. Cluster targets use the same Host/Port
  fields, so the same requirement applies.
- GetPeerByIP returns NotFound on a tunnel-IP miss instead of mapping
  every error to Internal. The proxy's ValidateTunnelPeer probes IPs
  that legitimately aren't in the roster; the miss is expected and now
  distinguishable from a real store failure.
- Thread ctx into getClusterCapability's gorm query so a cancelled
  request doesn't keep the store busy.

Tests updated for the L4-cluster port requirement and the GetPeerByIP
NotFound path.

* fix(client): include offlinePeers in PeerStateByIP lookup

ReplaceOfflinePeers moves peers into d.offlinePeers but PeerStateByIP
only scanned d.peers. Callers (the local DNS filter via
localPeerConnectivity, embed.Client.IdentityForIP used by the
proxy's tunnel-peer validator) were treating known-but-offline peers
as unknown, which:

- causes the DNS filter to keep returning records pointing at peers
  that have no live tunnel, AND
- makes the proxy's local-roster check deny a request from such a
  peer rather than letting the cached management RPC carry the
  authorisation decision.

Search both slices in PeerStateByIP. Adds a unit test for the IPv4
and IPv6 offline-match paths.

* fix(rest): reject empty Delete path params in reverse-proxy clients

ReverseProxyClustersAPI.Delete and ReverseProxyTokensAPI.Delete passed
the path parameter into url.PathEscape without an empty check.
PathEscape("") returns "" which collapses the request onto the
collection endpoint ("/api/reverse-proxies/clusters/" /
"/api/reverse-proxies/proxy-tokens/"), so a caller bug delete with no
id reached a routable URL with surprising semantics (typically 405).

Short-circuit with a typed error before the request is built. Tests
mount a handler on the collection path that fails the test if hit, so
the regression is impossible to reintroduce silently.

* chore(api,ci,docs,test): private-service schema, proto-check, fixups

Non-functional cleanups and contract/CI hardening around the
private-service work:

API schema (openapi.yml):
- Require a non-empty access_groups and mode=http when private=true,
  on both Service and ServiceRequest, mirroring
  validatePrivateRequirements. mode stays optional-but-constrained
  (empty defaults to http server-side), matching runtime.

CI (proto-version-check.yml):
- Cover renamed .pb.go files (read base via previous_filename).
- Match protoc-gen-go-grpc version headers (optional "- " prefix and
  -gen-go-grpc suffix) so grpc-generated files are in scope.

Docs / comments:
- Reword Config field docs to say defaults are applied at Server.Start
  (initDefaults), not New.
- Rename the obsolete --private-inbound flag to --private across
  comments and the proto doc.

Pre-existing test fixups surfaced by review:
- Repair the integration-tagged validate_session_test.go (SignToken
  signature growth + new Manager interface methods).
- Fix the CI-skip boolean precedence so Windows isn't skipped
  unconditionally.
- Guard the router.HTTPListener type assertion with comma-ok.

* fix(proxy): background ctx for already-started AddPeer notification

The earlier ctx fix covered the async runClientStartup path but missed
the synchronous branch: when a service is added to an already-started
client, AddPeer called NotifyStatus with the caller's request-scoped
ctx. A cancelled request/stream could drop the connected notification
to management. Use context.Background() here too, matching
notifyClientReady.

Extends TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus to
pass a pre-cancelled caller ctx and assert the notification still ran
on a non-cancelled context.

* use the cmd context for roundtripper
2026-06-02 13:40:09 +02:00
Viktor Liu
e7c9182ff9 [client] Offer injected ICMPv6 echo replies to packet capture (#6321) 2026-06-01 19:38:00 +02:00
128 changed files with 5622 additions and 729 deletions

View File

@@ -45,4 +45,11 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,client

View File

@@ -158,7 +158,16 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,client
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -276,9 +285,17 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-exec 'sudo' -coverprofile=coverage.txt \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,relay
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
@@ -326,7 +343,15 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
go test -timeout 10m -p 1 -coverprofile=coverage.txt ./proxy/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,proxy
test_signal:
name: "Signal / Unit"
@@ -377,9 +402,17 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-exec 'sudo' -coverprofile=coverage.txt \
-timeout 10m ./signal/... ./shared/signal/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,signal
test_management:
name: "Management / Unit"
needs: [build-cache]
@@ -445,10 +478,18 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \
go test -tags=devcert -coverprofile=coverage.txt \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/... ./shared/management/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: unit,management
benchmark:
name: "Management / Benchmark"
needs: [build-cache]
@@ -687,6 +728,14 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration \
go test -tags=integration -coverprofile=coverage.txt \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/server/http/...
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: netbirdio/netbird
flags: integration,management

View File

@@ -20,15 +20,30 @@ jobs:
per_page: 100,
});
const modifiedPbFiles = files.filter(
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
);
if (modifiedPbFiles.length === 0) {
console.log('No modified .pb.go files to check');
// Cover renamed .pb.go files in addition to plain edits.
// Renamed entries land under the new path with previous_filename
// pointing at the base-side name, so we read the base content
// from the old path when present.
const changedPbFiles = files
.filter(f => (f.status === 'modified' || f.status === 'renamed')
&& f.filename.endsWith('.pb.go'))
.map(f => ({
headPath: f.filename,
basePath: f.previous_filename || f.filename,
}));
if (changedPbFiles.length === 0) {
console.log('No modified or renamed .pb.go files to check');
return;
}
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
// Matches the generator version headers protoc writes at the top
// of generated files:
// // protoc v3.21.12
// // protoc-gen-go v1.26.0
// // - protoc-gen-go-grpc v1.6.1 (grpc files prefix with "- ")
// The optional "- " prefix and the optional -gen-go / -gen-go-grpc
// suffixes keep the *_grpc.pb.go headers in scope.
const versionPattern = /^\s*\/\/\s+(?:-\s+)?protoc(?:-gen-go(?:-grpc)?)?\s+v[\d.]+/;
const baseSha = context.payload.pull_request.base.sha;
const headSha = context.payload.pull_request.head.sha;
@@ -55,20 +70,22 @@ jobs:
}
const violations = [];
for (const file of modifiedPbFiles) {
for (const file of changedPbFiles) {
const [base, head] = await Promise.all([
getVersionHeader(file.filename, baseSha),
getVersionHeader(file.filename, headSha),
getVersionHeader(file.basePath, baseSha),
getVersionHeader(file.headPath, headSha),
]);
if (!base.ok || !head.ok) {
core.warning(
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
`Skipping ${file.headPath}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
);
continue;
}
if (base.lines.join('\n') !== head.lines.join('\n')) {
violations.push({
file: file.filename,
file: file.basePath === file.headPath
? file.headPath
: `${file.basePath} → ${file.headPath}`,
base: base.lines,
head: head.lines,
});

View File

@@ -29,10 +29,10 @@ jobs:
persist-credentials: false
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
run: bash -x release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
run: bash -x release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff

View File

@@ -65,7 +65,7 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
if [ ${SIZE} -gt 62914560 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
exit 1
fi

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/version"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -100,6 +101,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
Anonymize: anonymizeFlag,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
CliVersion: version.NetbirdVersion(),
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
@@ -298,6 +300,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
Anonymize: anonymizeFlag,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
CliVersion: version.NetbirdVersion(),
}
if uploadBundleFlag {
request.UploadURL = uploadBundleURLFlag
@@ -432,6 +435,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
SyncResponse: syncResponse,
LogPath: logFilePath,
CPUProfile: nil,
DaemonVersion: version.NetbirdVersion(), // acting as daemon
},
debug.BundleConfig{
IncludeSystemInfo: true,

301
client/cmd/kubernetes.go Normal file
View File

@@ -0,0 +1,301 @@
package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"github.com/goccy/go-yaml"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
const (
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
)
var kubernetesCmd = &cobra.Command{
Use: "kubernetes",
Short: "Kubernetes cluster commands.",
Long: "Kubernetes cluster commands.",
}
var kubernetesListCmd = &cobra.Command{
Use: "list",
RunE: kubernetesList,
Short: "List Kubernetes clusters.",
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
}
var kubernetesWriteKubeconfigCmd = &cobra.Command{
Use: "write-kubeconfig",
RunE: kubernetesWriteKubeconfig,
Args: cobra.ExactArgs(1),
Short: "Write kubeconfig for a Kubernetes cluster.",
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
}
func init() {
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
}
func kubernetesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
if err != nil {
return err
}
if len(kcs) == 0 {
cmd.Println("No Kubernetes clusters available.")
return nil
}
cmd.Println("Available Kubernetes clusters:")
for _, k := range kcs {
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
}
return nil
}
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
kubeconfigPath, err := resolveKubeconfigPath(cmd)
if err != nil {
return err
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
clusterName := args[0]
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
if err != nil {
return err
}
if len(kcs) == 0 {
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
}
if len(kcs) > 1 {
return fmt.Errorf("too many Kubernetes clusters returned")
}
err = writeKubeconfig(kubeconfigPath, kcs[0])
if err != nil {
return err
}
return nil
}
type kubernetesCluster struct {
name string
url *url.URL
version string
}
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
httpClient := &http.Client{
Transport: transport,
}
resolver := net.Resolver{
// Required so both DNS records are returned.
// https://github.com/golang/go/issues/17093
PreferGo: true,
}
kcs := []kubernetesCluster{}
attempted := map[string]struct{}{}
for _, peer := range peers {
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
if err != nil {
return nil, err
}
for _, fqdn := range fqdns {
if _, ok := attempted[fqdn]; ok {
continue
}
attempted[fqdn] = struct{}{}
comps := strings.Split(fqdn, ".")
if len(comps) < 2 {
continue
}
if comps[1] != KubernetesDNSSuffix {
continue
}
if nameFilter != "" && nameFilter != comps[0] {
continue
}
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
if err != nil {
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
continue
}
kc := kubernetesCluster{
name: comps[0],
url: clusterURL,
version: clusterVersion,
}
if nameFilter != "" {
return []kubernetesCluster{kc}, nil
}
kcs = append(kcs, kc)
}
}
return kcs, nil
}
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
clusterURL, err := url.Parse("https://" + fqdn)
if err != nil {
return nil, "", err
}
versionURL, err := clusterURL.Parse("/version")
if err != nil {
return nil, "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
if err != nil {
return nil, "", err
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", err
}
versionData := map[string]string{}
err = json.Unmarshal(b, &versionData)
if err != nil {
return nil, "", err
}
version, ok := versionData["gitVersion"]
if !ok {
return nil, "", errors.New("no version found in response")
}
return clusterURL, version, nil
}
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
if cmd.Flags().Changed("kubeconfig") {
path, err := cmd.Flags().GetString("kubeconfig")
if err != nil {
return "", err
}
return path, nil
}
if env := os.Getenv("KUBECONFIG"); env != "" {
return env, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("could not determine home directory: %w", err)
}
return filepath.Join(home, ".kube", "config"), nil
}
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
b, err := os.ReadFile(kubeconfigPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
var cfg map[string]any
if err := yaml.Unmarshal(b, &cfg); err != nil {
return err
}
if cfg == nil {
cfg = map[string]any{
"apiVersion": "v1",
"kind": "Config",
}
}
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
"name": kc.name,
"cluster": map[string]any{
"server": kc.url.String(),
"insecure-skip-tls-verify": true,
},
})
cfg["users"] = appendWithName(cfg["users"], map[string]any{
"name": "netbird",
"user": map[string]any{
"token": "none",
},
})
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
"name": kc.name,
"context": map[string]any{
"cluster": kc.name,
"user": "netbird",
"namespace": "default",
},
})
cfg["current-context"] = kc.name
out, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
return err
}
return nil
}
func appendWithName(data any, add map[string]any) any {
if data == nil {
return []any{add}
}
v, ok := data.([]any)
if !ok {
return []any{add}
}
i := slices.IndexFunc(v, func(item any) bool {
m, ok := item.(map[string]any)
if !ok {
return false
}
return m["name"] == add["name"]
})
if i == -1 {
return append(v, add)
}
v[i] = add
return v
}

View File

@@ -0,0 +1,120 @@
package cmd
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
func TestFingerprintClusters(t *testing.T) {
t.Parallel()
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck
w.Write([]byte(`{"gitVersion": "foobar"}`))
}))
defer srv.Close()
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
require.NoError(t, err)
require.Equal(t, srv.URL, clusterURL.String())
require.Equal(t, "foobar", clusterVersion)
}
func TestResolveKubeconfigPath(t *testing.T) {
home, err := os.UserHomeDir()
if err != nil {
t.Fatalf("could not determine home directory: %v", err)
}
defaultPath := filepath.Join(home, ".kube", "config")
path, err := resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, defaultPath, path)
flagPath := "flag-path"
cmd := &cobra.Command{}
cmd.Flags().String("kubeconfig", "", "")
err = cmd.Flags().Set("kubeconfig", flagPath)
require.NoError(t, err)
path, err = resolveKubeconfigPath(cmd)
require.NoError(t, err)
require.Equal(t, flagPath, path)
envPath := "env-path"
t.Setenv("KUBECONFIG", envPath)
path, err = resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, envPath, path)
}
func TestWriteKubeconfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existing string
}{
{
name: "empty file",
},
{
name: "existing content",
existing: `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://foobar.com
name: foo
current-context: test
kind: Config
users: []
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
kubeconfigPath := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
require.NoError(t, err)
kc := kubernetesCluster{
name: "foo",
url: &url.URL{Scheme: "https", Host: "example.com"},
}
err = writeKubeconfig(kubeconfigPath, kc)
require.NoError(t, err)
b, err := os.ReadFile(kubeconfigPath)
require.NoError(t, err)
expected := `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://example.com
name: foo
contexts:
- context:
cluster: foo
namespace: default
user: netbird
name: foo
current-context: foo
kind: Config
users:
- name: netbird
user:
token: none
`
require.Equal(t, expected, string(b))
})
}
}

View File

@@ -169,6 +169,11 @@ func init() {
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
// kubernetes commands
rootCmd.AddCommand(kubernetesCmd)
kubernetesCmd.AddCommand(kubernetesListCmd)
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)

View File

@@ -102,7 +102,7 @@ func (p *program) Stop(srv service.Service) error {
}
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc, consoleLog bool) (service.Service, error) {
// rootCmd env vars are already applied by PersistentPreRunE.
SetFlagsFromEnvVars(serviceCmd)
@@ -112,8 +112,14 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
return nil, err
}
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
if consoleLog {
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
} else {
if err := util.InitLog(logLevel, logFiles...); err != nil {
return nil, fmt.Errorf("init log: %w", err)
}
}
cfg, err := newSVCConfig()
@@ -138,7 +144,7 @@ var runCmd = &cobra.Command{
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -152,7 +158,7 @@ var startCmd = &cobra.Command{
Short: "starts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -170,7 +176,7 @@ var stopCmd = &cobra.Command{
Short: "stops NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -188,7 +194,7 @@ var restartCmd = &cobra.Command{
Short: "restarts NetBird service",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
if err != nil {
return err
}
@@ -206,7 +212,7 @@ var svcStatusCmd = &cobra.Command{
Short: "shows NetBird service status",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
s, err := setupServiceControlCommand(cmd, ctx, cancel)
s, err := setupServiceControlCommand(cmd, ctx, cancel, true)
if err != nil {
return err
}

View File

@@ -12,7 +12,13 @@ var (
Short: "Print the NetBird's client application version",
Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout())
cmd.Println(version.NetbirdVersion())
out := version.NetbirdVersion()
if version.IsDevelopmentVersion(out) {
if commit := version.NetbirdCommit(); commit != "" {
out += "-" + commit
}
}
cmd.Println(out)
},
}
)

View File

@@ -279,6 +279,10 @@ func (c *Client) Start(startCtx context.Context) error {
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}

168
client/embed/embed_test.go Normal file
View File

@@ -0,0 +1,168 @@
package embed
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
// holding the engine mutex. When the Start context expires, the rollback path
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
func TestClientStartTimeoutRollback(t *testing.T) {
signalAddr := startBlackholeSignal(t)
mgmAddr := startManagement(t, signalAddr)
wgPort := 0
client, err := New(Options{
DeviceName: "embed-rollback-test",
SetupKey: testSetupKey,
ManagementURL: "http://" + mgmAddr,
WireguardPort: &wgPort,
})
require.NoError(t, err, "embed client creation must succeed")
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
startErr := make(chan error, 1)
go func() {
startErr <- client.Start(startCtx)
}()
select {
case err := <-startErr:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(60 * time.Second):
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
}
}
// startBlackholeSignal starts a gRPC server without the SignalExchange service
// registered. Connections succeed, but the signal stream can never be
// established, which keeps Engine.Start parked in WaitStreamConnected.
func startBlackholeSignal(t *testing.T) string {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}
func startManagement(t *testing.T, signalAddr string) string {
t.Helper()
cfg := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: t.TempDir(),
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
require.NoError(t, err)
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
jobManager := job.NewJobManager(nil, testStore, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
require.NoError(t, err)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
require.NoError(t, err)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
require.NoError(t, err)
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
require.NoError(t, err)
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}

View File

@@ -3,6 +3,7 @@ package iptables
import (
"errors"
"fmt"
"maps"
"net"
"slices"
@@ -421,12 +422,17 @@ func (m *aclManager) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the maps so the persisted state holds a private snapshot. The
// live maps keep being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing them by reference races the two and aborts the process with a
// concurrent map iteration and write.
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
currentState.ACLEntries6 = maps.Clone(m.entries)
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
currentState.ACLEntries = maps.Clone(m.entries)
currentState.ACLIPsetStore = m.ipsetStore.clone()
}
if err := m.stateManager.UpdateState(currentState); err != nil {

View File

@@ -4,6 +4,7 @@ package iptables
import (
"fmt"
"maps"
"net/netip"
"strconv"
"strings"
@@ -749,11 +750,17 @@ func (r *router) updateState() {
currentState.Lock()
defer currentState.Unlock()
// Clone the rule map so the persisted state holds a private snapshot. The
// live map keeps being mutated by subsequent rule operations while the
// state manager marshals the state from its periodic-save goroutine.
// Sharing it by reference races the two and aborts the process with a
// concurrent map iteration and write. The ipset counter guards itself
// during marshaling, so it can be shared directly.
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteRules6 = maps.Clone(r.rules)
currentState.RouteIPsetCounter6 = r.ipsetCounter
} else {
currentState.RouteRules = r.rules
currentState.RouteRules = maps.Clone(r.rules)
currentState.RouteIPsetCounter = r.ipsetCounter
}

View File

@@ -1,6 +1,9 @@
package iptables
import "encoding/json"
import (
"encoding/json"
"maps"
)
type ipList struct {
ips map[string]struct{}
@@ -19,6 +22,14 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// clone returns a deep copy of the ipList with its own ips map.
func (s *ipList) clone() *ipList {
if s == nil {
return nil
}
return &ipList{ips: maps.Clone(s.ips)}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
@@ -55,6 +66,19 @@ func newIpsetStore() *ipsetStore {
}
}
// clone returns a deep copy of the ipsetStore with its own ipsets map and
// independent ipList entries.
func (s *ipsetStore) clone() *ipsetStore {
if s == nil {
return nil
}
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
for name, list := range s.ipsets {
cloned.ipsets[name] = list.clone()
}
return cloned
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok

View File

@@ -362,6 +362,10 @@ func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload
return 0
}
if pc := f.endpoint.capture.Load(); pc != nil {
(*pc).Offer(fullPacket, true)
}
return len(fullPacket)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"net/netip"
"path/filepath"
"runtime"
"runtime/debug"
"strings"
@@ -346,6 +347,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
// Leave StateDir empty when there is no state path so a disk-backed
// syncstore falls back to os.TempDir() instead of filepath.Dir("") == ".".
if path != "" {
engineConfig.StateDir = filepath.Dir(path)
}
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)

View File

@@ -254,6 +254,8 @@ type BundleGenerator struct {
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
daemonVersion string
cliVersion string
anonymize bool
includeSystemInfo bool
@@ -278,6 +280,8 @@ type GeneratorDependencies struct {
CapturePath string
RefreshStatus func()
ClientMetrics MetricsExporter
DaemonVersion string
CliVersion string
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -299,6 +303,8 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
daemonVersion: deps.DaemonVersion,
cliVersion: deps.CliVersion,
anonymize: cfg.Anonymize,
includeSystemInfo: cfg.IncludeSystemInfo,
@@ -459,9 +465,11 @@ func (g *BundleGenerator) addStatus() error {
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
Anonymize: g.anonymize,
ProfileName: profName,
Anonymize: g.anonymize,
ProfileName: profName,
DaemonVersion: g.daemonVersion,
})
overview.CliVersion = g.cliVersion
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
@@ -798,6 +806,8 @@ func (g *BundleGenerator) addSyncResponse() error {
AllowPartial: true,
}
g.maskSecrets()
jsonBytes, err := options.Marshal(g.syncResponse)
if err != nil {
return fmt.Errorf("generate json: %w", err)
@@ -810,6 +820,27 @@ func (g *BundleGenerator) addSyncResponse() error {
return nil
}
func (g *BundleGenerator) maskSecrets() {
if g.syncResponse == nil || g.syncResponse.NetbirdConfig == nil {
return
}
if g.syncResponse.NetbirdConfig.Flow != nil {
g.syncResponse.NetbirdConfig.Flow.TokenPayload = maskedValue
}
if g.syncResponse.NetbirdConfig.Relay != nil {
g.syncResponse.NetbirdConfig.Relay.TokenPayload = maskedValue
}
for i := range g.syncResponse.NetbirdConfig.Turns {
if g.syncResponse.NetbirdConfig.Turns[i] != nil {
g.syncResponse.NetbirdConfig.Turns[i].Password = maskedValue
}
}
}
func (g *BundleGenerator) addStateFile() error {
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
@@ -1039,7 +1070,8 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
return
}
pattern := filepath.Join(logDir, "client-*.log.gz")
// This regex will match both logs rotated by us and logrotate on linux
pattern := filepath.Join(logDir, "client*.log.*")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
@@ -1072,7 +1104,12 @@ func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
for i := 0; i < maxFiles; i++ {
name := filepath.Base(files[i])
if err := g.addSingleLogFileGz(files[i], name); err != nil {
if strings.HasSuffix(name, ".gz") {
err = g.addSingleLogFileGz(files[i], name)
} else {
err = g.addSingleLogfile(files[i], name)
}
if err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}

View File

@@ -0,0 +1,103 @@
package debug
import (
"archive/zip"
"bytes"
"compress/gzip"
"io"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestAddRotatedLogFiles_PicksUpAllVariants asserts that the rotated-log
// glob picks up logs rotated by timberjack (gzipped) and by logrotate (plain
// and gzipped), and skips unrelated files.
func TestAddRotatedLogFiles_PicksUpAllVariants(t *testing.T) {
dir := t.TempDir()
writeFile(t, filepath.Join(dir, "client.log"), "active log\n")
writeFile(t, filepath.Join(dir, "other.log"), "unrelated\n")
timberjackRotated := "client-2026-05-21T10-30-45.000.log.gz"
writeGzFile(t, filepath.Join(dir, timberjackRotated), "timberjack rotated content\n")
logrotatePlain := "client.log.1"
writeFile(t, filepath.Join(dir, logrotatePlain), "logrotate plain content\n")
logrotateGz := "client.log.2.gz"
writeGzFile(t, filepath.Join(dir, logrotateGz), "logrotate gz content\n")
names := runAddRotatedLogFiles(t, dir, 10)
require.Contains(t, names, timberjackRotated, "timberjack rotated file should be in bundle")
require.Contains(t, names, logrotatePlain, "logrotate plain rotated file should be in bundle")
require.Contains(t, names, logrotateGz, "logrotate gzipped rotated file should be in bundle")
require.NotContains(t, names, "client.log", "active log should not be added by addRotatedLogFiles")
require.NotContains(t, names, "other.log", "unrelated files should not be in bundle")
}
// TestAddRotatedLogFiles_RespectsLogFileCount asserts that only the newest
// logFileCount rotated files are bundled, ordered by mtime.
func TestAddRotatedLogFiles_RespectsLogFileCount(t *testing.T) {
dir := t.TempDir()
oldest := filepath.Join(dir, "client.log.3")
middle := filepath.Join(dir, "client.log.2")
newest := filepath.Join(dir, "client.log.1")
writeFile(t, oldest, "old\n")
writeFile(t, middle, "mid\n")
writeFile(t, newest, "new\n")
now := time.Now()
require.NoError(t, os.Chtimes(oldest, now.Add(-2*time.Hour), now.Add(-2*time.Hour)))
require.NoError(t, os.Chtimes(middle, now.Add(-1*time.Hour), now.Add(-1*time.Hour)))
require.NoError(t, os.Chtimes(newest, now, now))
names := runAddRotatedLogFiles(t, dir, 2)
require.Contains(t, names, "client.log.1")
require.Contains(t, names, "client.log.2")
require.NotContains(t, names, "client.log.3", "oldest file should be dropped when logFileCount=2")
}
// runAddRotatedLogFiles calls addRotatedLogFiles against a fresh in-memory
// zip writer and returns the set of entry names that ended up in the archive.
func runAddRotatedLogFiles(t *testing.T, dir string, logFileCount uint32) map[string]struct{} {
t.Helper()
var buf bytes.Buffer
g := &BundleGenerator{
archive: zip.NewWriter(&buf),
logFileCount: logFileCount,
}
g.addRotatedLogFiles(dir)
require.NoError(t, g.archive.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
names := make(map[string]struct{}, len(zr.File))
for _, f := range zr.File {
names[f.Name] = struct{}{}
}
return names
}
func writeFile(t *testing.T, path, content string) {
t.Helper()
require.NoError(t, os.WriteFile(path, []byte(content), 0o644))
}
func writeGzFile(t *testing.T, path, content string) {
t.Helper()
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
_, err := io.WriteString(gw, content)
require.NoError(t, err)
require.NoError(t, gw.Close())
require.NoError(t, os.WriteFile(path, buf.Bytes(), 0o644))
}

View File

@@ -777,13 +777,24 @@ func (s *DefaultServer) applyHostConfig() {
// context is released rather than leaked until GC.
func (s *DefaultServer) registerFallback() {
originalNameservers := s.hostManager.getOriginalNameservers()
if len(originalNameservers) == 0 {
serverIP := s.service.RuntimeIP()
var servers []netip.AddrPort
for _, ns := range originalNameservers {
if ns == serverIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
continue
}
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
if len(servers) == 0 {
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
s.clearFallback()
return
}
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
handler, err := newUpstreamResolver(
s.ctx,
@@ -797,11 +808,6 @@ func (s *DefaultServer) registerFallback() {
return
}
handler.selectedRoutes = s.selectedRoutes
var servers []netip.AddrPort
for _, ns := range originalNameservers {
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
}
handler.addRace(servers)
prev := s.fallbackHandler

View File

@@ -22,7 +22,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/proto"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
@@ -56,6 +55,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/syncstore"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
@@ -72,6 +72,7 @@ import (
sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/capture"
"github.com/netbirdio/netbird/version"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -148,6 +149,10 @@ type EngineConfig struct {
LogPath string
TempDir string
// StateDir is the directory holding the state file. The sync response
// (network map) is serialized here on platforms that persist it to disk.
StateDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -226,10 +231,15 @@ type Engine struct {
afpacketCapture *capture.AFPacketCapture
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
// Sync response persistence (protected by syncRespMux).
// syncStore is nil unless persistence has been enabled; its presence is
// what marks persistence as active. The backend (disk or memory) is
// selected per-platform; see the syncstore package. syncStoreDir is where
// a disk-backed store serializes to.
syncRespMux sync.RWMutex
syncStore syncstore.Store
syncStoreDir string
flowManager nftypes.FlowManager
// auto-update
@@ -292,6 +302,7 @@ func NewEngine(
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
syncStoreDir: config.StateDir,
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -869,63 +880,25 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return err
}
err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
err = e.handleRelayUpdate(wCfg.GetRelay())
if err != nil {
return err
}
err = e.handleFlowUpdate(wCfg.GetFlow())
if err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
// Posture checks are bound to the network map presence:
// NetworkMap != nil, checks present -> apply the received checks
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled
if enabled {
e.syncRespMux.Lock()
e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
e.persistSyncResponse(update)
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -937,6 +910,64 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
// which is the case for sync updates carrying only a network map.
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
if wCfg == nil {
return nil
}
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
return err
}
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
return nil
}
// persistSyncResponse stores the full sync response so it can be restored on the next
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
if e.syncStore == nil {
return
}
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
return
}
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too
@@ -1063,6 +1094,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
state.FQDN = conf.GetFqdn()
state.WgPort = e.config.WgPort
e.statusRecorder.UpdateLocalPeerState(state)
@@ -1141,6 +1173,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
DaemonVersion: version.NetbirdVersion(),
RefreshStatus: func() {
e.RunHealthProbes(true)
},
@@ -1813,6 +1846,18 @@ func (e *Engine) close() {
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
// Drop any persisted sync response so its network map does not linger on
// disk after the engine stops (and cannot leak into a later run).
e.syncRespMux.Lock()
store := e.syncStore
e.syncStore = nil
e.syncRespMux.Unlock()
if store != nil {
if err := store.Clear(); err != nil {
log.Warnf("failed to clear persisted sync response on close: %v", err)
}
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -2142,45 +2187,42 @@ func (e *Engine) stopDNSServer() {
e.statusRecorder.UpdateDNSStates(nsGroupStates)
}
// SetSyncResponsePersistence enables or disables sync response persistence
// SetSyncResponsePersistence enables or disables sync response persistence.
// The store is only instantiated while persistence is enabled; construction
// itself drops any stale data left over from an earlier run (see syncstore).
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncRespMux.Lock()
defer e.syncRespMux.Unlock()
if enabled == e.persistSyncResponse {
if enabled == (e.syncStore != nil) {
return
}
e.persistSyncResponse = enabled
log.Debugf("Sync response persistence is set to %t", enabled)
if !enabled {
e.latestSyncResponse = nil
if err := e.syncStore.Clear(); err != nil {
log.Warnf("failed to clear persisted sync response: %v", err)
}
e.syncStore = nil
return
}
e.syncStore = syncstore.New(e.syncStoreDir)
}
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
// Hold the lock for the whole Get so the store cannot be cleared
// (disabled / engine close) mid-call.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
defer e.syncRespMux.RUnlock()
if !enabled {
if e.syncStore == nil {
return nil, errors.New("sync response persistence is disabled")
}
if latest == nil {
//nolint:nilnil
return nil, nil
}
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("failed to clone sync response")
}
return sr, nil
//nolint:nilnil
return e.syncStore.Get()
}
// GetWgAddr returns the wireguard address
@@ -2216,7 +2258,7 @@ func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
) {
if e.config.DisableServerRoutes {
if e.config.DisableServerRoutes || e.config.BlockInbound {
return
}

View File

@@ -4,6 +4,8 @@ import (
"strings"
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
)
var (
@@ -11,7 +13,7 @@ var (
)
func IsSupported(agentVersion string) bool {
if agentVersion == "development" {
if nbversion.IsDevelopmentVersion(agentVersion) {
return true
}

View File

@@ -27,7 +27,7 @@ type Logger struct {
wgIfaceNetV6 netip.Prefix
dnsCollection atomic.Bool
exitNodeCollection atomic.Bool
Store types.Store
Store types.AggregatingStore
}
func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger {
@@ -35,7 +35,7 @@ func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix)
statusRecorder: statusRecorder,
wgIfaceNet: wgIfaceIPNet,
wgIfaceNetV6: wgIfaceIPNetV6,
Store: store.NewMemoryStore(),
Store: store.NewAggregatingMemoryStore(),
}
}
@@ -125,6 +125,10 @@ func (l *Logger) stop() {
l.mux.Unlock()
}
func (l *Logger) ResetAggregationWindow() types.FlowEventAggregator {
return l.Store.ResetAggregationWindow()
}
func (l *Logger) GetEvents() []*types.Event {
return l.Store.GetEvents()
}

View File

@@ -9,12 +9,14 @@ import (
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger"
"github.com/netbirdio/netbird/client/internal/netflow/store"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client"
@@ -23,14 +25,16 @@ import (
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
receiverClient *client.GRPCClient
publicKey []byte
cancel context.CancelFunc
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
receiverClient *client.GRPCClient
eventsWithoutAcks nftypes.Store
publicKey []byte
cancel context.CancelFunc
retryInterval time.Duration
}
// NewManager creates a new netflow manager
@@ -48,9 +52,11 @@ func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *pee
}
return &Manager{
logger: flowLogger,
conntrack: ct,
publicKey: publicKey,
logger: flowLogger,
conntrack: ct,
publicKey: publicKey,
retryInterval: time.Second,
eventsWithoutAcks: store.NewMemoryStore(),
}
}
@@ -107,7 +113,7 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
m.shutdownWg.Add(2)
m.shutdownWg.Add(3)
go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient)
@@ -116,6 +122,10 @@ func (m *Manager) resetClient() error {
defer m.shutdownWg.Done()
m.startSender(ctx)
}()
go func() {
defer m.shutdownWg.Done()
m.startRetries(ctx)
}()
return nil
}
@@ -207,13 +217,15 @@ func (m *Manager) startSender(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
events := m.logger.GetEvents()
collectedEvents := m.logger.ResetAggregationWindow()
events := collectedEvents.GetAggregatedEvents()
for _, event := range events {
if err := m.send(event); err != nil {
log.Errorf("failed to send flow event to server: %v", err)
continue
} else {
log.Tracef("sent flow event: %s", event.ID)
}
log.Tracef("sent flow event: %s", event.ID)
m.eventsWithoutAcks.StoreEvent(event)
}
}
}
@@ -227,7 +239,7 @@ func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
return nil
}
log.Tracef("received flow event ack: %s", id)
m.logger.DeleteEvents([]uuid.UUID{id})
m.eventsWithoutAcks.DeleteEvents([]uuid.UUID{id})
return nil
})
@@ -236,6 +248,41 @@ func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
}
}
func (m *Manager) startRetries(ctx context.Context) {
ticker := time.NewTimer(m.retryInterval)
retryBackoff := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 1 * time.Second,
RandomizationFactor: 0.5,
Multiplier: 1.7,
MaxInterval: m.flowConfig.Interval / 2,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
for _, e := range m.eventsWithoutAcks.GetEvents() {
if e.Timestamp.Add(time.Second).After(time.Now()) {
// grace period on retries to avoid early retries
// do not retry if the event is less than 1 sec old
continue
}
if err := m.send(e); err != nil {
ticker = time.NewTimer(retryBackoff.NextBackOff()) //nolint:staticcheck,wastedassign
break
}
}
retryBackoff.Reset()
ticker = time.NewTimer(time.Second)
}
}
}
func (m *Manager) send(event *nftypes.Event) error {
m.mux.Lock()
client := m.receiverClient

View File

@@ -0,0 +1,291 @@
package netflow
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/flow/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
type testServer struct {
proto.UnimplementedFlowServiceServer
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
handlerDone chan struct{} // signaled each time Events() exits
handlerStarted chan struct{} // signaled each time Events() begins
}
func newTestServer(t *testing.T) *testServer {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s := &testServer{
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: listener.Addr().String(),
handlerDone: make(chan struct{}, 10),
handlerStarted: make(chan struct{}, 10),
}
proto.RegisterFlowServiceServer(s.grpcSrv, s)
go func() {
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
t.Logf("server error: %v", err)
}
}()
t.Cleanup(func() {
s.grpcSrv.Stop()
})
return s
}
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
defer func() {
select {
case s.handlerDone <- struct{}{}:
default:
}
}()
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
if err != nil {
return err
}
select {
case s.handlerStarted <- struct{}{}:
default:
}
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
go func() {
defer cancel()
for {
event, err := stream.Recv()
if err != nil {
return
}
if !event.IsInitiator {
select {
case s.events <- event:
case <-ctx.Done():
return
}
}
}
}()
for {
select {
case ack := <-s.acks:
if err := stream.Send(ack); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
func TestSendEventReceiveAck(t *testing.T) {
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
server := newTestServer(t)
manager := createManager(t, server.addr, 60*time.Second) // set high to prevent retries in this test
defer manager.Close()
assert.Eventually(t, func() bool {
select {
case <-server.handlerStarted:
return true
default:
return false
}
}, 3*time.Second, 100*time.Millisecond)
event1 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.2"),
DestPort: 2345,
Protocol: 6,
}
manager.logger.StoreEvent(event1)
event2 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.1"),
DestPort: 1234,
Protocol: 6,
}
manager.logger.StoreEvent(event2)
// verify the server received logged events
serverSideEvents := make([]*proto.FlowEvent, 0)
assert.Eventually(t, func() bool {
select {
case event := <-server.events:
serverSideEvents = append(serverSideEvents, event)
if len(serverSideEvents) == 2 {
return true
}
default:
if len(serverSideEvents) == 2 {
return true
}
}
return false
}, 5*time.Second, 100*time.Millisecond)
serverSideFlowIds := make([]uuid.UUID, 0, 2)
slices.Values(serverSideEvents)(func(e *proto.FlowEvent) bool {
id, err := uuid.FromBytes(e.FlowFields.FlowId)
assert.NoError(t, err)
serverSideFlowIds = append(serverSideFlowIds, id)
return true
})
assert.ElementsMatch(t, []uuid.UUID{event1.FlowID, event2.FlowID}, serverSideFlowIds)
// verify the manager tracks un-acked events
unackedEvents := manager.eventsWithoutAcks.GetEvents()
assert.Len(t, unackedEvents, 2)
flowIds := make([]uuid.UUID, 0)
slices.Values(unackedEvents)(func(e *types.Event) bool {
flowIds = append(flowIds, e.FlowID)
return true
})
assert.ElementsMatch(t, flowIds, []uuid.UUID{event1.FlowID, event2.FlowID})
}
// verify handling of retries:
// - unacked events are retried
// - when acks arrive, events are removed from the un-acked event tracker
func TestRetryEvents(t *testing.T) {
_, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
server := newTestServer(t)
manager := createManager(t, server.addr, time.Second) // set low to start retries sooner
defer manager.Close()
assert.Eventually(t, func() bool {
select {
case <-server.handlerStarted:
return true
default:
return false
}
}, 3*time.Second, 100*time.Millisecond)
event1 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.2"),
DestPort: 2345,
Protocol: 6,
}
manager.logger.StoreEvent(event1)
event2 := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
DestIP: ipAddr("172.16.1.1"),
DestPort: 1234,
Protocol: 6,
}
manager.logger.StoreEvent(event2)
// verify the server received retries of logged events
serverSideEvents := make([]*proto.FlowEvent, 0)
func() {
c := time.After(2500 * time.Millisecond)
for {
select {
case event := <-server.events:
serverSideEvents = append(serverSideEvents, event)
case <-c:
return
}
}
}()
assert.True(t, len(serverSideEvents) > 2) // must see retries
uniqueServerSideEvents := make(map[uuid.UUID]*proto.FlowEvent)
slices.Values(serverSideEvents)(func(e *proto.FlowEvent) bool {
id, err := uuid.FromBytes(e.FlowFields.FlowId)
assert.NoError(t, err)
uniqueServerSideEvents[id] = e
return true
})
assert.Contains(t, uniqueServerSideEvents, event1.FlowID)
assert.Contains(t, uniqueServerSideEvents, event2.FlowID)
// ack events
server.acks <- &proto.FlowEventAck{EventId: uniqueServerSideEvents[event1.FlowID].EventId}
server.acks <- &proto.FlowEventAck{EventId: uniqueServerSideEvents[event2.FlowID].EventId}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
unackedEvents := manager.eventsWithoutAcks.GetEvents()
assert.Empty(c, unackedEvents)
}, 3*time.Second, 100*time.Millisecond)
}
func createManager(t *testing.T, serverAddr string, retryInterval time.Duration) *Manager {
t.Helper()
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: netip.MustParsePrefix("192.168.1.1/32"),
},
isUserspaceBind: true,
}
publicKey := []byte("test-public-key")
manager := NewManager(mockIFace, publicKey, nil)
manager.retryInterval = retryInterval
initialConfig := &types.FlowConfig{
Enabled: true,
URL: fmt.Sprintf("http://%s", serverAddr),
TokenPayload: "initial-payload",
TokenSignature: "initial-signature",
Interval: 500 * time.Millisecond,
}
err := manager.Update(initialConfig)
require.NoError(t, err)
return manager
}
func ipAddr(a string) netip.Addr {
addr, _ := netip.ParseAddr(a)
return addr
}

View File

@@ -0,0 +1,190 @@
package store
import (
"math/rand"
"net/netip"
"testing"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/stretchr/testify/assert"
)
var random = rand.New(rand.NewSource(time.Now().UnixNano()))
func TestFlowAggregation(t *testing.T) {
var protocols = []types.Protocol{types.ICMP, types.ICMPv6, types.TCP, types.UDP}
var tests = []struct {
description string
eventTypes []types.Type
}{
{
description: "start and stop",
eventTypes: []types.Type{types.TypeStart, types.TypeEnd},
},
{
description: "start and drop",
eventTypes: []types.Type{types.TypeStart, types.TypeDrop},
},
{
description: "start only",
eventTypes: []types.Type{types.TypeStart},
},
{
description: "drop only",
eventTypes: []types.Type{types.TypeDrop},
}}
for _, protocol := range protocols {
for _, tt := range tests {
t.Run(tt.description+" "+protocol.String(), func(t *testing.T) {
store := NewAggregatingMemoryStore()
allExpected := make([]*types.Event, 0)
for i := 0; i < 2; i++ {
inEvents, expected := generateEvents(tt.eventTypes, protocol, types.Ingress, 0)
for _, e := range inEvents {
store.StoreEvent(e)
}
allExpected = append(allExpected, expected)
}
events := store.GetAggregatedEvents()
assert.ElementsMatch(t, events, allExpected)
})
}
}
}
func TestIcmpEventAggregation(t *testing.T) {
var protocols = []types.Protocol{types.ICMP, types.ICMPv6}
var icmpTypes = []uint8{1, 2, 3}
var tests = []struct {
description string
eventTypes []types.Type
}{
{
description: "start and stop",
eventTypes: []types.Type{types.TypeStart, types.TypeEnd},
},
{
description: "start and drop",
eventTypes: []types.Type{types.TypeStart, types.TypeDrop},
},
{
description: "start only",
eventTypes: []types.Type{types.TypeStart},
},
{
description: "drop only",
eventTypes: []types.Type{types.TypeDrop},
}}
for _, protocol := range protocols {
for _, tt := range tests {
t.Run(tt.description+" "+protocol.String(), func(t *testing.T) {
store := NewAggregatingMemoryStore()
allExpected := make([]*types.Event, 0)
for _, icmpType := range icmpTypes {
events, expected := generateEvents(tt.eventTypes, protocol, types.Ingress, icmpType)
for _, e := range events {
store.StoreEvent(e)
}
allExpected = append(allExpected, expected)
}
aggregatedEvents := store.GetAggregatedEvents()
assert.Len(t, aggregatedEvents, len(allExpected))
assert.ElementsMatch(t, aggregatedEvents, allExpected)
})
}
}
}
func ipAddr(a string) netip.Addr {
addr, _ := netip.ParseAddr(a)
return addr
}
func generateEvents(eventTypes []types.Type, protocol types.Protocol, direction types.Direction, icmpType uint8) ([]*types.Event, *types.Event) {
var rxPackets, txPackets, rxBytes, txBytes uint64
inEvents := make([]*types.Event, 0)
ts := time.Now()
flowId := uuid.New()
srcIp := ipAddr("1.1.1.1")
srcPort := uint16(random.Uint32() >> 16)
dstIp := ipAddr("2.2.2.2")
dstPort := uint16(random.Uint32() >> 16)
for idx, eventType := range eventTypes {
e := &types.Event{
ID: uuid.New(),
Timestamp: ts.Add(time.Duration(idx) * time.Second),
EventFields: types.EventFields{
FlowID: flowId,
Type: eventType,
Protocol: protocol,
RuleID: []byte("rule-id-1"),
Direction: direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: random.Uint64(),
TxPackets: random.Uint64(),
RxBytes: random.Uint64(),
TxBytes: random.Uint64(),
}}
rxBytes += e.RxBytes
txBytes += e.TxBytes
rxPackets += e.RxPackets
txPackets += e.TxPackets
inEvents = append(inEvents, e)
if protocol == types.ICMP || protocol == types.ICMPv6 {
e.ICMPType = icmpType
}
}
var start, end, drop uint64
for _, eventType := range eventTypes {
switch eventType {
case types.TypeStart:
start += 1
case types.TypeDrop:
drop += 1
case types.TypeEnd:
end += 1
}
}
aggregatedEvent := &types.Event{
ID: inEvents[0].ID,
Timestamp: inEvents[0].Timestamp,
EventFields: types.EventFields{
FlowID: flowId,
Type: inEvents[0].Type,
Protocol: inEvents[0].Protocol,
RuleID: []byte("rule-id-1"),
Direction: inEvents[0].Direction,
SourceIP: srcIp,
SourcePort: srcPort,
DestIP: dstIp,
DestPort: dstPort,
SourceResourceID: []byte("source-resource-id"),
DestResourceID: []byte("dest-resource-id"),
RxPackets: rxPackets,
TxPackets: txPackets,
RxBytes: rxBytes,
TxBytes: txBytes,
NumOfStarts: start,
NumOfEnds: end,
NumOfDrops: drop,
}}
if protocol == types.ICMP || protocol == types.ICMPv6 {
aggregatedEvent.ICMPType = icmpType
}
return inEvents, aggregatedEvent
}

View File

@@ -1,10 +1,13 @@
package store
import (
"maps"
"net/netip"
"slices"
"sync"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -19,6 +22,10 @@ type Memory struct {
events map[uuid.UUID]*types.Event
}
type AggregatingMemory struct {
Memory
}
func (m *Memory) StoreEvent(event *types.Event) {
m.mux.Lock()
defer m.mux.Unlock()
@@ -48,3 +55,78 @@ func (m *Memory) DeleteEvents(ids []uuid.UUID) {
delete(m.events, id)
}
}
func NewAggregatingMemoryStore() *AggregatingMemory {
return &AggregatingMemory{Memory{events: make(map[uuid.UUID]*types.Event)}}
}
func (am *AggregatingMemory) ResetAggregationWindow() types.FlowEventAggregator {
am.mux.Lock()
defer am.mux.Unlock()
toret := AggregatingMemory{Memory: Memory{events: am.events}}
am.events = make(map[uuid.UUID]*types.Event)
return &toret
}
type aggregationKey struct {
destAddr netip.Addr
destPort uint16
protocol uint8
icmpType uint8
unique int64 // used to prevent aggregation on non icmp/udp/tcp events
}
func (am *AggregatingMemory) GetAggregatedEvents() []*types.Event {
aggregated := make(map[aggregationKey]*types.Event)
for _, v := range am.events {
lookupKey := aggregationKey{destAddr: v.DestIP, destPort: v.DestPort, protocol: uint8(v.Protocol), icmpType: v.ICMPType}
if _, ok := aggregated[lookupKey]; !ok {
aggregated[lookupKey] = v.Clone()
event := aggregated[lookupKey]
if event.Protocol != types.ICMP && event.Protocol != types.ICMPv6 && event.Protocol != types.UDP && event.Protocol != types.TCP {
lookupKey.unique = time.Now().UnixNano() // to make the lookup key unique so we don't aggregate on it
continue
}
switch event.Type {
case types.TypeStart:
event.NumOfStarts += 1
case types.TypeDrop:
event.NumOfDrops += 1
case types.TypeEnd:
event.NumOfEnds += 1
}
continue
}
aggregatedEvent := aggregated[lookupKey]
if aggregatedEvent.Protocol != types.ICMP && aggregatedEvent.Protocol != types.ICMPv6 && aggregatedEvent.Protocol != types.UDP && aggregatedEvent.Protocol != types.TCP {
continue // we don't aggregate this type of events; shouldn't ever get here
}
// track the number of connections, duration?, open and close events?
aggregatedEvent.RxBytes += v.RxBytes
aggregatedEvent.RxPackets += v.RxPackets
aggregatedEvent.TxBytes += v.TxBytes
aggregatedEvent.TxPackets += v.TxPackets
switch v.Type {
case types.TypeStart:
aggregatedEvent.NumOfStarts += 1
case types.TypeDrop:
aggregatedEvent.NumOfDrops += 1
case types.TypeEnd:
aggregatedEvent.NumOfEnds += 1
}
if aggregatedEvent.Timestamp.Compare(v.Timestamp) > 0 {
aggregatedEvent.Timestamp = v.Timestamp
aggregatedEvent.ID = v.ID
aggregatedEvent.Type = v.Type
}
// do we aggregate icmp by code?
}
return slices.Collect(maps.Values(aggregated)) // could return an iterator instead here
}

View File

@@ -2,6 +2,7 @@ package types
import (
"net/netip"
"slices"
"strconv"
"time"
@@ -92,6 +93,17 @@ type EventFields struct {
TxPackets uint64
RxBytes uint64
TxBytes uint64
NumOfStarts uint64
NumOfEnds uint64
NumOfDrops uint64
}
func (e *Event) Clone() *Event {
toret := *e
toret.RuleID = slices.Clone(e.RuleID)
toret.SourceResourceID = slices.Clone(e.SourceResourceID)
toret.DestResourceID = slices.Clone(e.DestResourceID)
return &toret
}
type FlowConfig struct {
@@ -114,13 +126,15 @@ type FlowManager interface {
GetLogger() FlowLogger
}
type FlowEventAggregator interface {
ResetAggregationWindow() FlowEventAggregator
GetAggregatedEvents() []*Event
}
type FlowLogger interface {
ResetAggregationWindow() FlowEventAggregator
// StoreEvent stores a flow event
StoreEvent(flowEvent EventFields)
// GetEvents returns all stored events
GetEvents() []*Event
// DeleteEvents deletes events from the store
DeleteEvents([]uuid.UUID)
// Close closes the logger
Close()
// Enable enables the flow logger receiver
@@ -140,6 +154,11 @@ type Store interface {
Close()
}
type AggregatingStore interface {
FlowEventAggregator
Store
}
// ConnTracker defines the interface for connection tracking functionality
type ConnTracker interface {
// Start begins tracking connections by listening for conntrack events.

View File

@@ -111,6 +111,7 @@ type LocalPeerState struct {
PubKey string
KernelInterface bool
FQDN string
WgPort int
Routes map[string]struct{}
}
@@ -310,8 +311,12 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Returns the
// zero State and false when no peer matches or the input is empty.
// address so dual-stack peers are reachable on either family. Searches
// both d.peers and d.offlinePeers — peers that have been moved into
// the offline slice by ReplaceOfflinePeers are still part of the
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
// need to recognise them rather than treating them as unknown. Returns
// the zero State and false when no peer matches or the input is empty.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
@@ -324,6 +329,11 @@ func (d *Status) PeerStateByIP(ip string) (State, bool) {
return state, true
}
}
for _, state := range d.offlinePeers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
}
return State{}, false
}
@@ -1348,6 +1358,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.WgPort = int32(fs.LocalPeerState.WgPort)
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)

View File

@@ -90,6 +90,28 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
// been moved into the offline slice via ReplaceOfflinePeers. Callers
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
// rather than unknown — otherwise authentication / DNS filtering treats
// known-but-offline peers as foreign IPs.
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
status.ReplaceOfflinePeers([]State{
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
})
state, ok := status.PeerStateByIP("100.64.0.20")
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
state, ok = status.PeerStateByIP("fd00::20")
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
key := "abc"
fqdn := "peer-a.netbird.local"

View File

@@ -700,6 +700,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
if m.routeSelector.IsDeselectAll() {
return
}
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
return

View File

@@ -0,0 +1,71 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
return route.HAMap{
haID: []*route.Route{
{
ID: "r-" + route.ID(netID),
NetID: netID,
Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network,
Enabled: true,
SkipAutoApply: skipAutoApply,
},
},
}
}
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
})
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
})
t.Run("user selection is not overridden by management", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
routes := exitNodeRoutes("exit1", true)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
})
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
m.routeSelector.DeselectAllRoutes()
routes := exitNodeRoutes("exit1", false)
m.updateRouteSelectorFromManagement(routes)
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
})
}

View File

@@ -116,6 +116,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
clear(rs.selectedRoutes)
}
// IsDeselectAll reports whether the user has explicitly deselected all routes.
func (rs *RouteSelector) IsDeselectAll() bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.deselectAll
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()

View File

@@ -0,0 +1,99 @@
package syncstore
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"sync"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
// syncResponseFileName is the name of the file the sync response is serialized
// to, placed inside the configured directory (the state directory).
const syncResponseFileName = "networkmap.pb"
// diskStore serializes the latest sync response to a file on disk instead of
// keeping it in memory. This trades disk I/O for a much smaller memory
// footprint, which matters on memory-constrained platforms (iOS).
type diskStore struct {
mu sync.Mutex
path string
}
// NewDiskStore returns a Store that serializes the sync response to a file in
// the given directory. If dir is empty it falls back to the OS temp directory.
//
// Any file left over from a previous run is removed on construction so a fresh
// store never reads stale data (e.g. another profile's network map).
func NewDiskStore(dir string) Store {
if dir == "" {
dir = os.TempDir()
}
s := &diskStore{
path: filepath.Join(dir, syncResponseFileName),
}
if err := s.Clear(); err != nil {
log.Warnf("failed to clear stale sync response file: %v", err)
}
return s
}
func (s *diskStore) Set(resp *mgmProto.SyncResponse) error {
if resp == nil {
return s.Clear()
}
bs, err := proto.Marshal(resp)
if err != nil {
return fmt.Errorf("marshal sync response: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
if err := util.WriteBytesWithRestrictedPermission(context.Background(), s.path, bs); err != nil {
return fmt.Errorf("write sync response to %s: %w", s.path, err)
}
log.Debugf("sync response persisted to %s (%d bytes)", s.path, len(bs))
return nil
}
func (s *diskStore) Get() (*mgmProto.SyncResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
bs, err := os.ReadFile(s.path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
//nolint:nilnil // nil,nil means "nothing stored", per the Store contract; preserve the original behaviour
return nil, nil
}
return nil, fmt.Errorf("read sync response from %s: %w", s.path, err)
}
resp := &mgmProto.SyncResponse{}
if err := proto.Unmarshal(bs, resp); err != nil {
return nil, fmt.Errorf("unmarshal sync response: %w", err)
}
log.Debugf("retrieving latest sync response from %s (%d bytes)", s.path, len(bs))
return resp, nil
}
func (s *diskStore) Clear() error {
s.mu.Lock()
defer s.mu.Unlock()
if err := os.Remove(s.path); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("remove sync response file %s: %w", s.path, err)
}
return nil
}

View File

@@ -0,0 +1,9 @@
//go:build ios
package syncstore
// New returns the platform default store. On iOS the sync response is
// serialized to disk (in dir) to keep it out of the constrained process memory.
func New(dir string) Store {
return NewDiskStore(dir)
}

View File

@@ -0,0 +1,9 @@
//go:build !ios
package syncstore
// New returns the platform default store. On all non-iOS platforms the sync
// response is kept in memory; dir is unused.
func New(_ string) Store {
return NewMemoryStore()
}

View File

@@ -0,0 +1,56 @@
package syncstore
import (
"fmt"
"sync"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// memoryStore keeps the latest sync response in memory.
type memoryStore struct {
mu sync.RWMutex
latest *mgmProto.SyncResponse
}
// NewMemoryStore returns a Store that keeps the sync response in memory.
func NewMemoryStore() Store {
return &memoryStore{}
}
func (s *memoryStore) Set(resp *mgmProto.SyncResponse) error {
s.mu.Lock()
defer s.mu.Unlock()
s.latest = resp
return nil
}
func (s *memoryStore) Get() (*mgmProto.SyncResponse, error) {
s.mu.RLock()
latest := s.latest
s.mu.RUnlock()
if latest == nil {
//nolint:nilnil // nil,nil means "nothing stored", per the Store contract; preserve the original behaviour
return nil, nil
}
log.Debugf("retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("clone sync response")
}
return sr, nil
}
func (s *memoryStore) Clear() error {
s.mu.Lock()
defer s.mu.Unlock()
s.latest = nil
return nil
}

View File

@@ -0,0 +1,29 @@
// Package syncstore stores the latest Management sync response (which carries
// the network map) for debug bundle generation.
//
// The storage backend is selected at build time per operating system: on iOS
// the response is serialized to disk to keep it out of the (tightly
// constrained) process memory, while on all other platforms it is kept in
// memory. The backend is chosen by the New constructor; see factory_ios.go and
// factory_other.go.
package syncstore
import (
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Store persists the latest sync response and returns it on demand.
//
// Implementations must be safe for concurrent use.
type Store interface {
// Set stores the given sync response, replacing any previously stored one.
Set(resp *mgmProto.SyncResponse) error
// Get returns the stored sync response, or nil if none is stored.
// The returned value is an independent copy that the caller may retain.
Get() (*mgmProto.SyncResponse, error)
// Clear removes any stored sync response. It is safe to call when nothing
// is stored.
Clear() error
}

View File

@@ -19,8 +19,6 @@ import (
const (
latestVersion = "latest"
// this version will be ignored
developmentVersion = "development"
)
var errNoUpdateState = errors.New("no update state found")
@@ -483,7 +481,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
}
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
if m.currentVersion == developmentVersion {
if version.IsDevelopmentVersion(m.currentVersion) {
log.Debugf("skipping auto-update, running development version")
return false
}

View File

@@ -1614,6 +1614,7 @@ type LocalPeerState struct {
RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"`
Ipv6 string `protobuf:"bytes,8,opt,name=ipv6,proto3" json:"ipv6,omitempty"`
WgPort int32 `protobuf:"varint,9,opt,name=wgPort,proto3" json:"wgPort,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1704,6 +1705,13 @@ func (x *LocalPeerState) GetIpv6() string {
return ""
}
func (x *LocalPeerState) GetWgPort() int32 {
if x != nil {
return x.WgPort
}
return 0
}
// SignalState contains the latest state of a signal connection
type SignalState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -2709,6 +2717,7 @@ type DebugBundleRequest struct {
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
CliVersion string `protobuf:"bytes,6,opt,name=cliVersion,proto3" json:"cliVersion,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -2771,6 +2780,13 @@ func (x *DebugBundleRequest) GetLogFileCount() uint32 {
return 0
}
func (x *DebugBundleRequest) GetCliVersion() string {
if x != nil {
return x.CliVersion
}
return ""
}
type DebugBundleResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
@@ -6389,7 +6405,7 @@ const file_daemon_proto_rawDesc = "" +
"\n" +
"sshHostKey\x18\x13 \x01(\fR\n" +
"sshHostKey\x12\x12\n" +
"\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x84\x02\n" +
"\x04ipv6\x18\x14 \x01(\tR\x04ipv6\"\x9c\x02\n" +
"\x0eLocalPeerState\x12\x0e\n" +
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" +
@@ -6398,7 +6414,8 @@ const file_daemon_proto_rawDesc = "" +
"\x10rosenpassEnabled\x18\x05 \x01(\bR\x10rosenpassEnabled\x120\n" +
"\x13rosenpassPermissive\x18\x06 \x01(\bR\x13rosenpassPermissive\x12\x1a\n" +
"\bnetworks\x18\a \x03(\tR\bnetworks\x12\x12\n" +
"\x04ipv6\x18\b \x01(\tR\x04ipv6\"S\n" +
"\x04ipv6\x18\b \x01(\tR\x04ipv6\x12\x16\n" +
"\x06wgPort\x18\t \x01(\x05R\x06wgPort\"S\n" +
"\vSignalState\x12\x10\n" +
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
@@ -6475,14 +6492,17 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xb4\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
"\n" +
"systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" +
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\x12\x1e\n" +
"\n" +
"cliVersion\x18\x06 \x01(\tR\n" +
"cliVersion\"}\n" +
"\x13DebugBundleResponse\x12\x12\n" +
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +

View File

@@ -349,6 +349,7 @@ message LocalPeerState {
bool rosenpassPermissive = 6;
repeated string networks = 7;
string ipv6 = 8;
int32 wgPort = 9;
}
// SignalState contains the latest state of a signal connection
@@ -471,6 +472,7 @@ message DebugBundleRequest {
bool systemInfo = 3;
string uploadURL = 4;
uint32 logFileCount = 5;
string cliVersion = 6;
}
message DebugBundleResponse {

View File

@@ -1,17 +1,16 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
if ! which realpath >/dev/null 2>&1; then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
script_path=$(dirname "$(realpath "$0")")
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.6.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/version"
)
// DebugBundle creates a debug bundle and returns the location.
@@ -67,6 +68,8 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
CapturePath: capturePath,
RefreshStatus: refreshStatus,
ClientMetrics: clientMetrics,
DaemonVersion: version.NetbirdVersion(),
CliVersion: req.CliVersion,
},
debug.BundleConfig{
Anonymize: req.GetAnonymize(),

View File

@@ -143,6 +143,7 @@ type OutputOverview struct {
IPv6 string `json:"netbirdIpv6,omitempty" yaml:"netbirdIpv6,omitempty"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
WgPort int `json:"wireguardPort" yaml:"wireguardPort"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
@@ -187,6 +188,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
IPv6: pbFullStatus.GetLocalPeerState().GetIpv6(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
WgPort: int(pbFullStatus.GetLocalPeerState().GetWgPort()),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
@@ -547,6 +549,21 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
daemonVersion := "N/A"
if o.DaemonVersion != "" {
daemonVersion = o.DaemonVersion
}
cliVersion := version.NetbirdVersion()
if o.CliVersion != "" {
cliVersion = o.CliVersion
}
wgPortString := "N/A"
if o.WgPort > 0 {
wgPortString = fmt.Sprintf("%d", o.WgPort)
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
@@ -560,6 +577,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"NetBird IP: %s\n"+
"%s"+
"Interface type: %s\n"+
"Wireguard port: %s\n"+
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
@@ -567,8 +585,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"%s"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
o.DaemonVersion,
version.NetbirdVersion(),
daemonVersion,
cliVersion,
o.ProfileName,
managementConnString,
signalConnString,
@@ -578,6 +596,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
interfaceIP,
ipv6Line,
interfaceTypeString,
wgPortString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,

View File

@@ -94,6 +94,7 @@ var resp = &proto.StatusResponse{
Ipv6: "fd00::100",
PubKey: "Some-Pub-Key",
KernelInterface: true,
WgPort: 51820,
Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{
"10.10.0.0/24",
@@ -210,6 +211,7 @@ var overview = OutputOverview{
IPv6: "fd00::100",
PubKey: "Some-Pub-Key",
KernelInterface: true,
WgPort: 51820,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []NsServerGroupStateOutput{
{
@@ -369,6 +371,7 @@ func TestParsingToJSON(t *testing.T) {
"netbirdIpv6": "fd00::100",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"wireguardPort": 51820,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
@@ -487,6 +490,7 @@ netbirdIp: 192.168.178.100/16
netbirdIpv6: fd00::100
publicKey: Some-Pub-Key
usesKernelInterface: true
wireguardPort: 51820
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
@@ -579,12 +583,13 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
NetBird IPv6: fd00::100
Interface type: Kernel
Wireguard port: %d
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion, overview.WgPort)
assert.Equal(t, expectedDetail, detail)
}
@@ -604,6 +609,7 @@ FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
NetBird IPv6: fd00::100
Interface type: Kernel
Wireguard port: 51820
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled

View File

@@ -502,7 +502,7 @@ func (s *serviceClient) getConnectionForm() *widget.Form {
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
{Text: "Interface Name", Widget: s.iInterfaceName},
{Text: "Interface Port", Widget: s.iInterfacePort},
{Text: "Interface Port", Widget: s.iInterfacePort, HintText: "If set to 0, a random free port will be used"},
{Text: "MTU", Widget: s.iMTU},
{Text: "Log File", Widget: s.iLogFile},
},
@@ -558,8 +558,8 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
if err != nil {
return 0, 0, errors.New("invalid interface port")
}
if port < 1 || port > 65535 {
return 0, 0, errors.New("invalid interface port: out of range 1-65535")
if port < 0 || port > 65535 {
return 0, 0, errors.New("invalid interface port: out of range 0-65535")
}
var mtu int64
@@ -1438,7 +1438,7 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
}
config.WgIface = cfg.InterfaceName
if cfg.WireguardPort != 0 {
if cfg.WireguardPort >= 0 && cfg.WireguardPort <= 65535 {
config.WgPort = int(cfg.WireguardPort)
} else {
config.WgPort = iface.DefaultWgPort

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
uptypes "github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/version"
)
// Initial state for the debug collection
@@ -462,6 +463,7 @@ func (s *serviceClient) createDebugBundleFromCollection(
request := &proto.DebugBundleRequest{
Anonymize: params.anonymize,
SystemInfo: params.systemInfo,
CliVersion: version.NetbirdVersion(),
}
if params.upload {
@@ -593,6 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
request := &proto.DebugBundleRequest{
Anonymize: anonymize,
SystemInfo: systemInfo,
CliVersion: version.NetbirdVersion(),
}
if uploadURL != "" {

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.21.9
// protoc-gen-go v1.36.11
// protoc v7.34.1
// source: flow.proto
package proto
@@ -12,6 +12,7 @@ import (
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
@@ -125,27 +126,24 @@ func (Direction) EnumDescriptor() ([]byte, []int) {
}
type FlowEvent struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// Unique client event identifier
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
// When the event occurred
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
// Public key of the sending peer
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *FlowEvent) Reset() {
*x = FlowEvent{}
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_flow_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *FlowEvent) String() string {
@@ -156,7 +154,7 @@ func (*FlowEvent) ProtoMessage() {}
func (x *FlowEvent) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -207,22 +205,19 @@ func (x *FlowEvent) GetIsInitiator() bool {
}
type FlowEventAck struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// Unique client event identifier that has been ack'ed
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
IsInitiator bool `protobuf:"varint,2,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
IsInitiator bool `protobuf:"varint,2,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *FlowEventAck) Reset() {
*x = FlowEventAck{}
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_flow_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *FlowEventAck) String() string {
@@ -233,7 +228,7 @@ func (*FlowEventAck) ProtoMessage() {}
func (x *FlowEventAck) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -263,10 +258,7 @@ func (x *FlowEventAck) GetIsInitiator() bool {
}
type FlowFields struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
// Unique client flow session identifier
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
// Flow type
@@ -283,7 +275,7 @@ type FlowFields struct {
DestIp []byte `protobuf:"bytes,7,opt,name=dest_ip,json=destIp,proto3" json:"dest_ip,omitempty"`
// Layer 4 -specific information
//
// Types that are assignable to ConnectionInfo:
// Types that are valid to be assigned to ConnectionInfo:
//
// *FlowFields_PortInfo
// *FlowFields_IcmpInfo
@@ -297,15 +289,18 @@ type FlowFields struct {
// Resource ID
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
NumOfStarts uint64 `protobuf:"varint,16,opt,name=num_of_starts,json=numOfStarts,proto3" json:"num_of_starts,omitempty"`
NumOfEnds uint64 `protobuf:"varint,17,opt,name=num_of_ends,json=numOfEnds,proto3" json:"num_of_ends,omitempty"`
NumOfDrops uint64 `protobuf:"varint,18,opt,name=num_of_drops,json=numOfDrops,proto3" json:"num_of_drops,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *FlowFields) Reset() {
*x = FlowFields{}
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_flow_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *FlowFields) String() string {
@@ -316,7 +311,7 @@ func (*FlowFields) ProtoMessage() {}
func (x *FlowFields) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -380,23 +375,27 @@ func (x *FlowFields) GetDestIp() []byte {
return nil
}
func (m *FlowFields) GetConnectionInfo() isFlowFields_ConnectionInfo {
if m != nil {
return m.ConnectionInfo
func (x *FlowFields) GetConnectionInfo() isFlowFields_ConnectionInfo {
if x != nil {
return x.ConnectionInfo
}
return nil
}
func (x *FlowFields) GetPortInfo() *PortInfo {
if x, ok := x.GetConnectionInfo().(*FlowFields_PortInfo); ok {
return x.PortInfo
if x != nil {
if x, ok := x.ConnectionInfo.(*FlowFields_PortInfo); ok {
return x.PortInfo
}
}
return nil
}
func (x *FlowFields) GetIcmpInfo() *ICMPInfo {
if x, ok := x.GetConnectionInfo().(*FlowFields_IcmpInfo); ok {
return x.IcmpInfo
if x != nil {
if x, ok := x.ConnectionInfo.(*FlowFields_IcmpInfo); ok {
return x.IcmpInfo
}
}
return nil
}
@@ -443,6 +442,27 @@ func (x *FlowFields) GetDestResourceId() []byte {
return nil
}
func (x *FlowFields) GetNumOfStarts() uint64 {
if x != nil {
return x.NumOfStarts
}
return 0
}
func (x *FlowFields) GetNumOfEnds() uint64 {
if x != nil {
return x.NumOfEnds
}
return 0
}
func (x *FlowFields) GetNumOfDrops() uint64 {
if x != nil {
return x.NumOfDrops
}
return 0
}
type isFlowFields_ConnectionInfo interface {
isFlowFields_ConnectionInfo()
}
@@ -463,21 +483,18 @@ func (*FlowFields_IcmpInfo) isFlowFields_ConnectionInfo() {}
// TCP/UDP port information
type PortInfo struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
SourcePort uint32 `protobuf:"varint,1,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
DestPort uint32 `protobuf:"varint,2,opt,name=dest_port,json=destPort,proto3" json:"dest_port,omitempty"`
unknownFields protoimpl.UnknownFields
SourcePort uint32 `protobuf:"varint,1,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
DestPort uint32 `protobuf:"varint,2,opt,name=dest_port,json=destPort,proto3" json:"dest_port,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *PortInfo) Reset() {
*x = PortInfo{}
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_flow_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *PortInfo) String() string {
@@ -488,7 +505,7 @@ func (*PortInfo) ProtoMessage() {}
func (x *PortInfo) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -519,21 +536,18 @@ func (x *PortInfo) GetDestPort() uint32 {
// ICMP message information
type ICMPInfo struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
IcmpType uint32 `protobuf:"varint,1,opt,name=icmp_type,json=icmpType,proto3" json:"icmp_type,omitempty"`
IcmpCode uint32 `protobuf:"varint,2,opt,name=icmp_code,json=icmpCode,proto3" json:"icmp_code,omitempty"`
unknownFields protoimpl.UnknownFields
IcmpType uint32 `protobuf:"varint,1,opt,name=icmp_type,json=icmpType,proto3" json:"icmp_type,omitempty"`
IcmpCode uint32 `protobuf:"varint,2,opt,name=icmp_code,json=icmpCode,proto3" json:"icmp_code,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *ICMPInfo) Reset() {
*x = ICMPInfo{}
if protoimpl.UnsafeEnabled {
mi := &file_flow_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_flow_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ICMPInfo) String() string {
@@ -544,7 +558,7 @@ func (*ICMPInfo) ProtoMessage() {}
func (x *ICMPInfo) ProtoReflect() protoreflect.Message {
mi := &file_flow_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -575,102 +589,83 @@ func (x *ICMPInfo) GetIcmpCode() uint32 {
var File_flow_proto protoreflect.FileDescriptor
var file_flow_proto_rawDesc = []byte{
0x0a, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x66, 0x6c,
0x6f, 0x77, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0xd4, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75,
0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d,
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63,
0x5f, 0x6b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c,
0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x31, 0x0a, 0x0b, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x66, 0x69,
0x65, 0x6c, 0x64, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x66, 0x6c, 0x6f,
0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x52, 0x0a, 0x66, 0x6c,
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e,
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69,
0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x4b, 0x0a, 0x0c, 0x46, 0x6c,
0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76,
0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76,
0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69,
0x61, 0x74, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x49, 0x6e,
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77,
0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12,
0x1e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e,
0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12,
0x17, 0x0a, 0x07, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x06, 0x72, 0x75, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x2d, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65,
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x66, 0x6c,
0x6f, 0x77, 0x2e, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x64, 0x69,
0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x63, 0x6f, 0x6c, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70,
0x18, 0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70,
0x12, 0x17, 0x0a, 0x07, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x07, 0x20, 0x01, 0x28,
0x0c, 0x52, 0x06, 0x64, 0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x2d, 0x0a, 0x09, 0x70, 0x6f, 0x72,
0x74, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66,
0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08,
0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2d, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c,
0x6f, 0x77, 0x2e, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x69,
0x63, 0x6d, 0x70, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x78, 0x5f, 0x70, 0x61,
0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x72, 0x78, 0x50,
0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x78, 0x5f, 0x70, 0x61, 0x63,
0x6b, 0x65, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x74, 0x78, 0x50, 0x61,
0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65,
0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73,
0x12, 0x19, 0x0a, 0x08, 0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01,
0x28, 0x04, 0x52, 0x07, 0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73,
0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69,
0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52,
0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73,
0x74, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20,
0x01, 0x28, 0x0c, 0x52, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63,
0x65, 0x49, 0x64, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f,
0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e,
0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72,
0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50,
0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74,
0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74,
0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09,
0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52,
0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d,
0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63,
0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10,
0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00,
0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10, 0x01,
0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12, 0x0d,
0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b, 0x0a,
0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44, 0x49,
0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12, 0x0a,
0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46, 0x6c,
0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76, 0x65,
0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45,
0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77,
0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08,
0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
const file_flow_proto_rawDesc = "" +
"\n" +
"\n" +
"flow.proto\x12\x04flow\x1a\x1fgoogle/protobuf/timestamp.proto\"\xd4\x01\n" +
"\tFlowEvent\x12\x19\n" +
"\bevent_id\x18\x01 \x01(\fR\aeventId\x128\n" +
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n" +
"\n" +
"public_key\x18\x03 \x01(\fR\tpublicKey\x121\n" +
"\vflow_fields\x18\x04 \x01(\v2\x10.flow.FlowFieldsR\n" +
"flowFields\x12 \n" +
"\visInitiator\x18\x05 \x01(\bR\visInitiator\"K\n" +
"\fFlowEventAck\x12\x19\n" +
"\bevent_id\x18\x01 \x01(\fR\aeventId\x12 \n" +
"\visInitiator\x18\x02 \x01(\bR\visInitiator\"\x82\x05\n" +
"\n" +
"FlowFields\x12\x17\n" +
"\aflow_id\x18\x01 \x01(\fR\x06flowId\x12\x1e\n" +
"\x04type\x18\x02 \x01(\x0e2\n" +
".flow.TypeR\x04type\x12\x17\n" +
"\arule_id\x18\x03 \x01(\fR\x06ruleId\x12-\n" +
"\tdirection\x18\x04 \x01(\x0e2\x0f.flow.DirectionR\tdirection\x12\x1a\n" +
"\bprotocol\x18\x05 \x01(\rR\bprotocol\x12\x1b\n" +
"\tsource_ip\x18\x06 \x01(\fR\bsourceIp\x12\x17\n" +
"\adest_ip\x18\a \x01(\fR\x06destIp\x12-\n" +
"\tport_info\x18\b \x01(\v2\x0e.flow.PortInfoH\x00R\bportInfo\x12-\n" +
"\ticmp_info\x18\t \x01(\v2\x0e.flow.ICMPInfoH\x00R\bicmpInfo\x12\x1d\n" +
"\n" +
"rx_packets\x18\n" +
" \x01(\x04R\trxPackets\x12\x1d\n" +
"\n" +
"tx_packets\x18\v \x01(\x04R\ttxPackets\x12\x19\n" +
"\brx_bytes\x18\f \x01(\x04R\arxBytes\x12\x19\n" +
"\btx_bytes\x18\r \x01(\x04R\atxBytes\x12,\n" +
"\x12source_resource_id\x18\x0e \x01(\fR\x10sourceResourceId\x12(\n" +
"\x10dest_resource_id\x18\x0f \x01(\fR\x0edestResourceId\x12\"\n" +
"\rnum_of_starts\x18\x10 \x01(\x04R\vnumOfStarts\x12\x1e\n" +
"\vnum_of_ends\x18\x11 \x01(\x04R\tnumOfEnds\x12 \n" +
"\fnum_of_drops\x18\x12 \x01(\x04R\n" +
"numOfDropsB\x11\n" +
"\x0fconnection_info\"H\n" +
"\bPortInfo\x12\x1f\n" +
"\vsource_port\x18\x01 \x01(\rR\n" +
"sourcePort\x12\x1b\n" +
"\tdest_port\x18\x02 \x01(\rR\bdestPort\"D\n" +
"\bICMPInfo\x12\x1b\n" +
"\ticmp_type\x18\x01 \x01(\rR\bicmpType\x12\x1b\n" +
"\ticmp_code\x18\x02 \x01(\rR\bicmpCode*E\n" +
"\x04Type\x12\x10\n" +
"\fTYPE_UNKNOWN\x10\x00\x12\x0e\n" +
"\n" +
"TYPE_START\x10\x01\x12\f\n" +
"\bTYPE_END\x10\x02\x12\r\n" +
"\tTYPE_DROP\x10\x03*;\n" +
"\tDirection\x12\x15\n" +
"\x11DIRECTION_UNKNOWN\x10\x00\x12\v\n" +
"\aINGRESS\x10\x01\x12\n" +
"\n" +
"\x06EGRESS\x10\x022B\n" +
"\vFlowService\x123\n" +
"\x06Events\x12\x0f.flow.FlowEvent\x1a\x12.flow.FlowEventAck\"\x00(\x010\x01B\bZ\x06/protob\x06proto3"
var (
file_flow_proto_rawDescOnce sync.Once
file_flow_proto_rawDescData = file_flow_proto_rawDesc
file_flow_proto_rawDescData []byte
)
func file_flow_proto_rawDescGZIP() []byte {
file_flow_proto_rawDescOnce.Do(func() {
file_flow_proto_rawDescData = protoimpl.X.CompressGZIP(file_flow_proto_rawDescData)
file_flow_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_flow_proto_rawDesc), len(file_flow_proto_rawDesc)))
})
return file_flow_proto_rawDescData
}
var file_flow_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
var file_flow_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_flow_proto_goTypes = []interface{}{
var file_flow_proto_goTypes = []any{
(Type)(0), // 0: flow.Type
(Direction)(0), // 1: flow.Direction
(*FlowEvent)(nil), // 2: flow.FlowEvent
@@ -701,69 +696,7 @@ func file_flow_proto_init() {
if File_flow_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_flow_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*FlowEvent); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*FlowEventAck); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*FlowFields); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PortInfo); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_flow_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ICMPInfo); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_flow_proto_msgTypes[2].OneofWrappers = []interface{}{
file_flow_proto_msgTypes[2].OneofWrappers = []any{
(*FlowFields_PortInfo)(nil),
(*FlowFields_IcmpInfo)(nil),
}
@@ -771,7 +704,7 @@ func file_flow_proto_init() {
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_flow_proto_rawDesc,
RawDescriptor: unsafe.Slice(unsafe.StringData(file_flow_proto_rawDesc), len(file_flow_proto_rawDesc)),
NumEnums: 2,
NumMessages: 5,
NumExtensions: 0,
@@ -783,7 +716,6 @@ func file_flow_proto_init() {
MessageInfos: file_flow_proto_msgTypes,
}.Build()
File_flow_proto = out.File
file_flow_proto_rawDesc = nil
file_flow_proto_goTypes = nil
file_flow_proto_depIdxs = nil
}

View File

@@ -75,6 +75,9 @@ message FlowFields {
bytes source_resource_id = 14;
bytes dest_resource_id = 15;
uint64 num_of_starts = 16;
uint64 num_of_ends = 17;
uint64 num_of_drops = 18;
}
// Flow event types

View File

@@ -1,4 +1,8 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.1
// - protoc v7.34.1
// source: flow.proto
package proto
@@ -11,15 +15,19 @@ import (
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
FlowService_Events_FullMethodName = "/flow.FlowService/Events"
)
// FlowServiceClient is the client API for FlowService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type FlowServiceClient interface {
// Client to receiver streams of events and acknowledgements
Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error)
Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error)
}
type flowServiceClient struct {
@@ -30,54 +38,40 @@ func NewFlowServiceClient(cc grpc.ClientConnInterface) FlowServiceClient {
return &flowServiceClient{cc}
}
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error) {
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], "/flow.FlowService/Events", opts...)
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], FlowService_Events_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &flowServiceEventsClient{stream}
x := &grpc.GenericClientStream[FlowEvent, FlowEventAck]{ClientStream: stream}
return x, nil
}
type FlowService_EventsClient interface {
Send(*FlowEvent) error
Recv() (*FlowEventAck, error)
grpc.ClientStream
}
type flowServiceEventsClient struct {
grpc.ClientStream
}
func (x *flowServiceEventsClient) Send(m *FlowEvent) error {
return x.ClientStream.SendMsg(m)
}
func (x *flowServiceEventsClient) Recv() (*FlowEventAck, error) {
m := new(FlowEventAck)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type FlowService_EventsClient = grpc.BidiStreamingClient[FlowEvent, FlowEventAck]
// FlowServiceServer is the server API for FlowService service.
// All implementations must embed UnimplementedFlowServiceServer
// for forward compatibility
// for forward compatibility.
type FlowServiceServer interface {
// Client to receiver streams of events and acknowledgements
Events(FlowService_EventsServer) error
Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error
mustEmbedUnimplementedFlowServiceServer()
}
// UnimplementedFlowServiceServer must be embedded to have forward compatible implementations.
type UnimplementedFlowServiceServer struct {
}
// UnimplementedFlowServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedFlowServiceServer struct{}
func (UnimplementedFlowServiceServer) Events(FlowService_EventsServer) error {
return status.Errorf(codes.Unimplemented, "method Events not implemented")
func (UnimplementedFlowServiceServer) Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error {
return status.Error(codes.Unimplemented, "method Events not implemented")
}
func (UnimplementedFlowServiceServer) mustEmbedUnimplementedFlowServiceServer() {}
func (UnimplementedFlowServiceServer) testEmbeddedByValue() {}
// UnsafeFlowServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to FlowServiceServer will
@@ -87,34 +81,22 @@ type UnsafeFlowServiceServer interface {
}
func RegisterFlowServiceServer(s grpc.ServiceRegistrar, srv FlowServiceServer) {
// If the following call panics, it indicates UnimplementedFlowServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&FlowService_ServiceDesc, srv)
}
func _FlowService_Events_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(FlowServiceServer).Events(&flowServiceEventsServer{stream})
return srv.(FlowServiceServer).Events(&grpc.GenericServerStream[FlowEvent, FlowEventAck]{ServerStream: stream})
}
type FlowService_EventsServer interface {
Send(*FlowEventAck) error
Recv() (*FlowEvent, error)
grpc.ServerStream
}
type flowServiceEventsServer struct {
grpc.ServerStream
}
func (x *flowServiceEventsServer) Send(m *FlowEventAck) error {
return x.ServerStream.SendMsg(m)
}
func (x *flowServiceEventsServer) Recv() (*FlowEvent, error) {
m := new(FlowEvent)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type FlowService_EventsServer = grpc.BidiStreamingServer[FlowEvent, FlowEventAck]
// FlowService_ServiceDesc is the grpc.ServiceDesc for FlowService service.
// It's only intended for direct use with grpc.RegisterService,

View File

@@ -10,8 +10,9 @@ fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
echo "$script_path"
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
#go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
#go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./flow.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"

View File

@@ -99,6 +99,9 @@ func addFields(entry *logrus.Entry) {
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
entry.Data[context.UserAgentKey] = ctxUserAgent
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}

8
go.mod
View File

@@ -2,6 +2,8 @@ module github.com/netbirdio/netbird
go 1.25.5
toolchain go1.25.11
require (
cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0
@@ -24,13 +26,13 @@ require (
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
gopkg.in/natefinch/lumberjack.v2 v2.2.1
)
require (
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
github.com/DeRuina/timberjack v1.4.2
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.38.3
github.com/aws/aws-sdk-go-v2/config v1.31.6
@@ -54,6 +56,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.4
github.com/goccy/go-yaml v1.18.0
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/mock v1.6.0
@@ -211,10 +214,9 @@ require (
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect

8
go.sum
View File

@@ -29,6 +29,8 @@ github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/DeRuina/timberjack v1.4.2 h1:4bKlzhKdsR+2oNkgef9mqb4n11ICow8VK88RfzJPzN8=
github.com/DeRuina/timberjack v1.4.2/go.mod h1:RLoeQrwrCGIEF8gO5nV5b/gMD0QIy7bzQhBUgpp1EqE=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0=
@@ -273,8 +275,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
@@ -940,8 +942,6 @@ gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8
gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=

View File

@@ -19,6 +19,46 @@ readonly MSG_SEPARATOR="=========================================="
# Utility Functions
############################################
check_docker_sock_perms() {
local sock="${DOCKER_HOST:-unix:///var/run/docker.sock}"
sock="${sock#unix://}"
if [[ ! -S "$sock" ]]; then
return 0
fi
if [[ ! -r "$sock" ]] || [[ ! -w "$sock" ]]; then
local group
if [[ "${OSTYPE}" == "darwin"* ]]; then
group="$(stat -f '%Sg' "$sock")"
else
group="$(stat -c '%G' "$sock")"
fi
echo "Cannot access Docker socket: $sock" > /dev/stderr
echo "" > /dev/stderr
echo "Socket permissions:" > /dev/stderr
ls -l "$sock" > /dev/stderr
echo "" > /dev/stderr
if [[ "$group" == "docker" ]]; then
echo "Your user may need to be added to the '$group' group:" > /dev/stderr
echo " sudo usermod -aG $group \"$USER\"" > /dev/stderr
echo "Then log out and back in, or run this for the current shell:" > /dev/stderr
echo " newgrp $group" > /dev/stderr
echo "Note: newgrp is temporary; usermod is the permanent group change." > /dev/stderr
else
echo "The Docker socket is owned by the '$group' group, which is not the standard 'docker' group." > /dev/stderr
echo "For safety, this script will not suggest adding your user to '$group'." > /dev/stderr
echo "Instead, either run this script with appropriate privileges (for example, via sudo) or follow Docker's post-install steps to configure access via the 'docker' group:" > /dev/stderr
echo " https://docs.docker.com/engine/install/linux-postinstall/" > /dev/stderr
fi
exit 1
fi
return 0
}
check_docker_compose() {
if command -v docker-compose &> /dev/null
then
@@ -311,11 +351,12 @@ initialize_default_values() {
NETBIRD_STUN_PORT=3478
# Docker images
DASHBOARD_IMAGE="netbirdio/dashboard:latest"
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
# Combined server replaces separate signal, relay, and management containers
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:latest"
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:latest"
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
NETBIRD_PROXY_IMAGE=${NETBIRD_PROXY_IMAGE:-"netbirdio/reverse-proxy:latest"}
TRAEFIK_IMAGE=${TRAEFIK_IMAGE:-"traefik:v3.6"}
CROWDSEC_IMAGE=${CROWDSEC_IMAGE:-"crowdsecurity/crowdsec:v1.7.7"}
# Reverse proxy configuration
REVERSE_PROXY_TYPE="0"
TRAEFIK_EXTERNAL_NETWORK=""
@@ -580,12 +621,15 @@ start_services_and_show_instructions() {
}
init_environment() {
# Check if docker compose is installed using check_docker_compose function
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
check_docker_sock_perms
initialize_default_values
configure_domain
configure_reverse_proxy
check_jq
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
check_existing_installation
generate_configuration_files
@@ -656,7 +700,7 @@ render_docker_compose_traefik_builtin() {
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
crowdsec_service="
crowdsec:
image: crowdsecurity/crowdsec:v1.7.7
image: $CROWDSEC_IMAGE
container_name: netbird-crowdsec
restart: unless-stopped
networks: [netbird]
@@ -687,7 +731,7 @@ render_docker_compose_traefik_builtin() {
services:
# Traefik reverse proxy (automatic TLS via Let's Encrypt)
traefik:
image: traefik:v3.6
image: $TRAEFIK_IMAGE
container_name: netbird-traefik
restart: unless-stopped
networks:
@@ -771,7 +815,7 @@ $traefik_dynamic_volume
labels:
- traefik.enable=true
# gRPC router (needs h2c backend for HTTP/2 cleartext)
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`))
- traefik.http.routers.netbird-grpc.rule=Host(\`$NETBIRD_DOMAIN\`) && (PathPrefix(\`/signalexchange.SignalExchange/\`) || PathPrefix(\`/management.ManagementService/\`) || PathPrefix(\`/management.ProxyService/\`))
- traefik.http.routers.netbird-grpc.entrypoints=websecure
- traefik.http.routers.netbird-grpc.tls=true
- traefik.http.routers.netbird-grpc.tls.certresolver=letsencrypt

View File

@@ -32,6 +32,7 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
type Controller struct {
@@ -514,7 +515,7 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
for _, peer := range peers {
// Development version is always supported
if peer.Meta.WtVersion == "development" {
if version.IsDevelopmentVersion(peer.Meta.WtVersion) {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)

View File

@@ -488,6 +488,195 @@ func TestUpdate_AllowsPortChange(t *testing.T) {
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
}
func TestUpdate_PreservesPortWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected by the custom-port capability check")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved on unsupported cluster")
}
func TestUpdate_PreservesPortWhenCustomPortsUnknown(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, nil)
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "update must not be rejected when cluster capability is unknown")
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when capability is unknown")
}
func TestUpdate_RejectsPortChangeWhenCustomPortsNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 54321,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "explicit port change on update must be rejected on unsupported clusters")
assert.Contains(t, err.Error(), "custom ports not supported on target cluster")
}
func TestUpdate_TLSPortChangeAllowedWhenNotSupported(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tls-svc", "tls", "app.example.com", testCluster, 443)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tls-svc",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err, "TLS port change uses SNI routing and is exempt from the custom-port check")
assert.Equal(t, uint16(9999), updated.ListenPort, "TLS port change should be applied")
}
func TestValidateL4PortDiffOnClusterDiff(t *testing.T) {
tests := []struct {
name string
mode string
customPorts *bool
newPort uint16
oldPort uint16
wantErr bool
}{
{"tcp port change unsupported", "tcp", boolPtr(false), 54321, 12345, true},
{"tcp port change unknown capability", "tcp", nil, 54321, 12345, true},
{"udp port change unsupported", "udp", boolPtr(false), 54321, 12345, true},
{"tcp first port assignment unsupported", "tcp", boolPtr(false), 54321, 0, true},
{"tcp port change supported", "tcp", boolPtr(true), 54321, 12345, false},
{"tcp port unchanged unsupported", "tcp", boolPtr(false), 12345, 12345, false},
{"tcp zero port unsupported", "tcp", boolPtr(false), 0, 12345, false},
{"tls port change unsupported", "tls", boolPtr(false), 9999, 443, false},
{"http mode ignored", "http", boolPtr(false), 54321, 12345, false},
{"empty mode ignored", "", boolPtr(false), 54321, 12345, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
newSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.newPort, ProxyCluster: testCluster}
oldSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.oldPort, ProxyCluster: testCluster}
err := validateL4PortDiffOnClusterDiff(tc.customPorts, newSvc, oldSvc)
if tc.wantErr {
assert.Error(t, err, "port diff should be rejected for %s", tc.name)
} else {
assert.NoError(t, err, "port diff should be allowed for %s", tc.name)
}
})
}
}
func TestUpdate_PortConflictRejected(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "tcp-a", "tcp", "tcp-a."+testCluster, testCluster, 5432)
svcB := seedService(t, testStore, "tcp-b", "tcp", "tcp-b."+testCluster, testCluster, 6543)
updated := &rpservice.Service{
ID: svcB.ID,
AccountID: testAccountID,
Name: "tcp-b",
Mode: "tcp",
Domain: "tcp-b." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.Error(t, err, "updating to a port held by another service should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestUpdate_AutoAssignsWhenNoPort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 0)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.True(t, updated.ListenPort >= autoAssignPortMin && updated.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range [%d, %d]", updated.ListenPort, autoAssignPortMin, autoAssignPortMax)
assert.True(t, updated.PortAutoAssigned, "PortAutoAssigned should be set when update triggers auto-assignment")
}
func TestCreateServiceFromPeer_TCP(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()

View File

@@ -338,7 +338,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
}
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
return err
}
@@ -367,11 +367,11 @@ func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service)
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool, serviceUpdate bool) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && !serviceUpdate && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
@@ -465,7 +465,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return err
}
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
return err
}
@@ -651,12 +651,22 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
// if the service is being updated, and we decide in the future to allow mode update,
// we should reconsider the currently assigned port if not 0 for clusters that don't support custom ports
if err := validateL4PortDiffOnClusterDiff(customPorts, service, existingService); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, service, customPorts, true); err != nil {
return err
}
// we can try carrying the previous service port into a new cluster, if this becomes a problem for multiple users,
// we should reconsider adding another check
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -664,6 +674,21 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
return nil
}
// validateL4PortDiffOnClusterDiff checks if custom L4 ports are configured and validates port changes across clusters.
// It ensures no port changes if custom ports are unsupported for a given cluster and protocol mode.
// Returns an error if validation fails, otherwise returns nil.
func validateL4PortDiffOnClusterDiff(customPorts *bool, newSVC, oldSVC *service.Service) error {
if !service.IsPortBasedProtocol(newSVC.Mode) || (customPorts != nil && *customPorts) {
return nil
}
if newSVC.ListenPort != 0 && newSVC.ListenPort != oldSVC.ListenPort {
return status.Errorf(status.InvalidArgument, "custom ports not supported on target cluster %s", newSVC.ProxyCluster)
}
return nil
}
// handleDomainChange validates the new domain is free inside the transaction
// and applies the pre-resolved cluster (computed outside the tx by
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks

View File

@@ -932,7 +932,11 @@ func (s *Service) validateL4Target(target *Target) error {
if target.TargetId == "" {
return errors.New("target_id is required for L4 services")
}
if target.TargetType != TargetTypeCluster && target.Port == 0 {
// Cluster targets resolve their upstream host:port from the target's
// own Host/Port fields just like the other L4 types — buildPathMappings
// emits net.JoinHostPort(target.Host, target.Port) for every L4
// target, so allowing port=0 here would let ":0" reach the proxy.
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
switch target.TargetType {

View File

@@ -1176,7 +1176,12 @@ func TestValidate_HTTPClusterTarget_RequiresDirectUpstream(t *testing.T) {
assert.ErrorContains(t, rp.Validate(), "direct upstream disabled", "cluster target must reject direct_upstream=false")
}
func TestValidate_L4ClusterTarget(t *testing.T) {
// TestValidate_L4ClusterTarget_RequiresPort confirms that an L4 cluster
// target without an explicit port is rejected. buildPathMappings emits
// net.JoinHostPort(target.Host, target.Port) for every L4 target — so
// allowing port=0 would let the proxy ship ":0" upstreams. The port
// requirement is the same as every other L4 target type.
func TestValidate_L4ClusterTarget_RequiresPort(t *testing.T) {
rp := validProxy()
rp.Mode = ModeTCP
rp.ListenPort = 9000
@@ -1186,7 +1191,12 @@ func TestValidate_L4ClusterTarget(t *testing.T) {
Protocol: "tcp",
Enabled: true,
}}
require.NoError(t, rp.Validate(), "L4 cluster target must validate without an explicit port")
assert.ErrorContains(t, rp.Validate(), "port is required",
"L4 cluster target must require an explicit port like other L4 target types")
rp.Targets[0].Port = 5432
rp.Targets[0].Host = "db.lan"
require.NoError(t, rp.Validate(), "L4 cluster target with host:port must validate")
}
func TestService_Copy_RoundtripsPrivate(t *testing.T) {

View File

@@ -122,7 +122,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.errCh = make(chan error, 4)
if s.autoResolveDomains {
s.resolveDomains(srvCtx)
s.ResolveDomains(srvCtx)
}
s.PeersManager()
@@ -398,10 +398,10 @@ func (s *BaseServer) serveGRPCWithHTTP(ctx context.Context, listener net.Listene
}()
}
// resolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// ResolveDomains determines dnsDomain and mgmtSingleAccModeDomain based on store state.
// Fresh installs use the default self-hosted domain, while existing installs reuse the
// persisted account domain to keep addressing stable across config changes.
func (s *BaseServer) resolveDomains(ctx context.Context) {
func (s *BaseServer) ResolveDomains(ctx context.Context) {
st := s.Store()
setDefault := func(logMsg string, args ...any) {

View File

@@ -22,7 +22,7 @@ func TestResolveDomains_FreshInstallUsesDefault(t *testing.T) {
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
srv.ResolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)
@@ -40,7 +40,7 @@ func TestResolveDomains_ExistingInstallUsesPersistedDomain(t *testing.T) {
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
srv.ResolveDomains(context.Background())
require.Equal(t, "vpn.mycompany.com", srv.dnsDomain)
require.Equal(t, "vpn.mycompany.com", srv.mgmtSingleAccModeDomain)
@@ -56,7 +56,7 @@ func TestResolveDomains_StoreErrorFallsBackToDefault(t *testing.T) {
srv := NewServer(&Config{NbConfig: &nbconfig.Config{}})
Inject[store.Store](srv, mockStore)
srv.resolveDomains(context.Background())
srv.ResolveDomains(context.Background())
require.Equal(t, DefaultSelfHostedDomain, srv.dnsDomain)
require.Equal(t, DefaultSelfHostedDomain, srv.mgmtSingleAccModeDomain)

View File

@@ -8,6 +8,8 @@ import (
"strings"
"time"
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -28,6 +30,23 @@ import (
"github.com/netbirdio/netbird/shared/sshauth"
)
const (
// deprecatedRemotePeersVersion is the version of Netbird that introduced the NetworkMap.RemotePeers field, deprecated in favor of RemotePeers.
deprecatedRemotePeersVersion = "0.29.3"
)
// precomputedDeprecatedRemotePeersConstraint is the parsed ">= 0.29.3" constraint,
// built once at init since the bound is a compile-time constant.
var precomputedDeprecatedRemotePeersConstraint version.Constraints
func init() {
constraint, err := version.NewConstraint(">= " + deprecatedRemotePeersVersion)
if err != nil {
panic("parse deprecated remote peers version constraint: " + err.Error())
}
precomputedDeprecatedRemotePeersConstraint = constraint
}
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
@@ -155,7 +174,11 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
response.RemotePeers = remotePeers
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
response.RemotePeers = remotePeers
}
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
@@ -246,6 +269,19 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m
return hashedUsers, machineUsers
}
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
if nbversion.IsDevelopmentVersion(peerVersion) {
return true
}
peerNBVersion, err := version.NewVersion(peerVersion)
if err != nil {
return false
}
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
@@ -363,7 +399,6 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {

View File

@@ -202,6 +202,42 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
}
}
// TestShouldSkipSendingDeprecatedRemotePeers covers the version gate that
// stops populating the deprecated top-level SyncResponse.RemotePeers field for
// peers new enough to read RemotePeers off the NetworkMap. Development builds
// are treated as latest and skip the field. The gate otherwise fails safe: a
// release version older than the boundary, or one that can't be parsed (empty,
// garbage, prereleases of the boundary) still receives the deprecated field so
// older/unknown clients keep working.
func TestShouldSkipSendingDeprecatedRemotePeers(t *testing.T) {
tests := []struct {
name string
peerVersion string
wantSkip bool
}{
{"exact boundary skips", "0.29.3", true},
{"newer patch skips", "0.29.4", true},
{"newer minor skips", "0.30.0", true},
{"newer major skips", "1.0.0", true},
{"v-prefixed newer skips", "v0.30.0", true},
{"development build skips", "development", true},
{"development build with commit skips", "development-abc123def456-dirty", true},
{"older patch keeps field", "0.29.2", false},
{"older minor keeps field", "0.28.0", false},
{"prerelease of boundary keeps field", "0.29.3-SNAPSHOT", false},
{"tagged dev prerelease keeps field", "v0.31.1-dev", false},
{"empty version keeps field", "", false},
{"garbage version keeps field", "not-a-version", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := shouldSkipSendingDeprecatedRemotePeers(tc.peerVersion)
assert.Equal(t, tc.wantSkip, got, "skip decision for peer version %q", tc.peerVersion)
})
}
}
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//

View File

@@ -666,8 +666,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
case resp := <-conn.sendChan:
if err := conn.sendResponse(resp); err != nil {
errChan <- err
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
return
}
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
case <-conn.ctx.Done():
return
}
@@ -978,6 +980,7 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
Mode: m.Mode,
ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
Private: m.Private,
}
}

View File

@@ -0,0 +1,88 @@
package grpc
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
// copy from the source, since callers assign it individually after cloning.
const authTokenField = "AuthToken"
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
// ProxyMapping with a non-zero value and verifies the clone carries each one
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
// without updating shallowCloneMapping fails this test.
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
src := &proto.ProxyMapping{}
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
clone := shallowCloneMapping(src)
require.NotNil(t, clone, "clone must not be nil")
srcVal := reflect.ValueOf(src).Elem()
cloneVal := reflect.ValueOf(clone).Elem()
for _, name := range populated {
srcField := srcVal.FieldByName(name).Interface()
cloneField := cloneVal.FieldByName(name).Interface()
if name == authTokenField {
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
continue
}
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
}
}
// populateExportedFields sets a non-zero value on every settable exported field
// of the struct and returns their names.
func populateExportedFields(t *testing.T, v reflect.Value) []string {
t.Helper()
var names []string
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
structField := typ.Field(i)
if structField.PkgPath != "" || !field.CanSet() {
continue
}
setNonZero(t, field, structField.Name)
names = append(names, structField.Name)
}
return names
}
// setNonZero assigns a deterministic non-zero value based on the field kind.
func setNonZero(t *testing.T, field reflect.Value, name string) {
t.Helper()
switch field.Kind() {
case reflect.String:
field.SetString("non-zero-" + name)
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Ptr:
field.Set(reflect.New(field.Type().Elem()))
case reflect.Slice:
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
case reflect.Map:
field.Set(reflect.MakeMapWithSize(field.Type(), 0))
default:
t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
}
}

View File

@@ -102,7 +102,7 @@ func generateSessionKeyPair(t *testing.T) (string, string) {
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, nil, time.Hour)
token, err := sessionkey.SignToken(privKeyB64, userID, "", domain, auth.MethodOIDC, nil, nil, time.Hour)
require.NoError(t, err)
return token
}
@@ -394,6 +394,10 @@ func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Cont
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsPrivate(_ context.Context, _ string) *bool {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store
}
@@ -401,3 +405,24 @@ type testValidateSessionUsersManager struct {
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}
func (m *testValidateSessionUsersManager) GetUserWithGroups(ctx context.Context, userID string) (*types.User, []*types.Group, error) {
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, nil, err
}
if len(user.AutoGroups) == 0 {
return user, nil, nil
}
groupsMap, err := m.store.GetGroupsByIDs(ctx, store.LockingStrengthNone, user.AccountID, user.AutoGroups)
if err != nil {
return nil, nil, err
}
groups := make([]*types.Group, 0, len(user.AutoGroups))
for _, id := range user.AutoGroups {
if g, ok := groupsMap[id]; ok && g != nil {
groups = append(groups, g)
}
}
return user, groups, nil
}

View File

@@ -12,6 +12,7 @@ const (
RoleKey = nbcontext.RoleKey
UserIDKey = nbcontext.UserIDKey
PeerIDKey = nbcontext.PeerIDKey
UserAgentKey = nbcontext.UserAgentKey
)
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.

View File

@@ -30,6 +30,7 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/version"
)
const remoteJobsMinVer = "0.64.0"
@@ -372,7 +373,7 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p
}
meetMinVer, err := posture.MeetsMinVersion(remoteJobsMinVer, p.Meta.WtVersion)
if !strings.Contains(p.Meta.WtVersion, "dev") && (!meetMinVer || err != nil) {
if !version.IsDevelopmentVersion(p.Meta.WtVersion) && (!meetMinVer || err != nil) {
return status.Errorf(status.PreconditionFailed, "peer version %s does not meet the minimum required version %s for remote jobs", p.Meta.WtVersion, remoteJobsMinVer)
}

View File

@@ -1216,6 +1216,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("NetworkResources").
Preload("Onboarding").
Preload("Services.Targets").
Preload("Domains").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
@@ -1302,7 +1303,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
var wg sync.WaitGroup
errChan := make(chan error, 12)
errChan := make(chan error, 16)
wg.Add(1)
go func() {
@@ -1403,6 +1404,17 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
account.Services = services
}()
wg.Add(1)
go func() {
defer wg.Done()
domains, err := s.ListCustomDomains(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Domains = domains
}()
wg.Add(1)
go func() {
defer wg.Done()
@@ -4734,7 +4746,13 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
result := tx.
Take(&peer, fmt.Sprintf("account_id = ? AND %s = ?", column), accountID, jsonValue)
if result.Error != nil {
// no logging here
// A tunnel-IP miss is an expected outcome (e.g. the proxy's
// ValidateTunnelPeer probing an address that isn't in the
// account roster); surface it as NotFound so callers can tell
// it apart from a real store failure.
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer with ip %s not found", ip.String())
}
return nil, status.Errorf(status.Internal, "failed to get peer from store")
}
@@ -5962,6 +5980,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
}
err := s.db.
WithContext(ctx).
Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").

View File

@@ -4,6 +4,8 @@ import (
"context"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
@@ -21,6 +23,63 @@ import (
"github.com/netbirdio/netbird/route"
)
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
// private service's DNS zone; without the preload the relation is empty and the
// service is silently skipped, so a custom domain never resolves on clients.
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
defer cleanup()
assertGetAccountLoadsCustomDomains(t, store)
}
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
assertGetAccountLoadsCustomDomains(t, store)
}
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
// paths: it persists two custom domains and asserts the relation comes back
// populated, which SynthesizePrivateServiceZones relies on.
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
accountID := "acct-custom-domains"
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
require.NoError(t, err, "creating the first custom domain must succeed")
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
require.NoError(t, err, "creating the second custom domain must succeed")
account, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
byDomain := map[string]string{}
for _, d := range account.Domains {
require.NotNil(t, d)
byDomain[d.Domain] = d.TargetCluster
}
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
}
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
// all fields and nested objects from the database, including deeply nested structures.
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {

View File

@@ -13,7 +13,7 @@ import (
)
func TestSqlStore_GetAccount_PrivateServiceRoundtrip(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
if os.Getenv("CI") == "true" && (runtime.GOOS == "darwin" || runtime.GOOS == "windows") {
t.Skip("skip CI tests on darwin and windows")
}

View File

@@ -491,6 +491,27 @@ func Test_GetAccount(t *testing.T) {
})
}
// TestSqlStore_GetPeerByIP_NotFound pins the not-found semantics the
// proxy's ValidateTunnelPeer relies on: a tunnel-IP that isn't in the
// account roster must surface as a NotFound error (not a generic
// Internal) so callers can distinguish an expected miss from a real
// store failure. A known IP still resolves.
func TestSqlStore_GetPeerByIP_NotFound(t *testing.T) {
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
const accountID = "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peer, err := store.GetPeerByIP(context.Background(), LockingStrengthNone, accountID, net.ParseIP("192.168.0.0"))
require.NoError(t, err, "known tunnel IP must resolve")
require.NotNil(t, peer)
_, err = store.GetPeerByIP(context.Background(), LockingStrengthNone, accountID, net.ParseIP("100.65.0.99"))
require.Error(t, err, "unknown tunnel IP must error")
parsedErr, ok := status.FromError(err)
require.True(t, ok, "error must be a status error")
require.Equal(t, status.NotFound, parsedErr.Type(), "tunnel-IP miss must be NotFound, not Internal")
})
}
func TestSqlStore_SavePeer(t *testing.T) {
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)

View File

@@ -21,6 +21,8 @@ const (
httpRequestCounterPrefix = "management.http.request.counter"
httpResponseCounterPrefix = "management.http.response.counter"
httpRequestDurationPrefix = "management.http.request.duration.ms"
RequestIDHeader = "X-Request-Id"
)
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
@@ -172,6 +174,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
reqID := xid.New().String()
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
//nolint
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
rw.Header().Set(RequestIDHeader, reqID)
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)

View File

@@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/version"
)
const (
@@ -272,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
peerGroups := a.GetPeerGroups(peerID)
zonesByCluster := map[string]*nbdns.CustomZone{}
zonesByApex := map[string]*nbdns.CustomZone{}
for _, svc := range a.Services {
if svc == nil || !svc.Enabled || !svc.Private {
@@ -289,19 +290,24 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
continue
}
zone, exists := zonesByCluster[svc.ProxyCluster]
serviceDomainZone := a.privateServiceDomainZone(svc)
if serviceDomainZone == "" {
continue
}
zone, exists := zonesByApex[serviceDomainZone]
if !exists {
// NonAuthoritative makes this a match-only zone: queries for
// names without an explicit record fall through to the
// upstream resolver instead of returning NXDOMAIN. Without
// it, adding a single private service would black-hole every
// other name under the cluster apex.
// other name under the zone apex.
zone = &nbdns.CustomZone{
Domain: dns.Fqdn(svc.ProxyCluster),
Domain: dns.Fqdn(serviceDomainZone),
Records: []nbdns.SimpleRecord{},
NonAuthoritative: true,
}
zonesByCluster[svc.ProxyCluster] = zone
zonesByApex[serviceDomainZone] = zone
}
emitted := 0
@@ -339,8 +345,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
}
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
for _, zone := range zonesByCluster {
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
for _, zone := range zonesByApex {
if len(zone.Records) == 0 {
continue
}
@@ -356,6 +362,33 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
return out
}
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
// looking at the proxy cluster domain then the custom domains.
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
return svc.ProxyCluster
}
// Longest matching custom domain wins
zoneName := ""
for _, d := range a.Domains {
if d == nil || d.TargetCluster != svc.ProxyCluster {
continue
}
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
zoneName = d.Domain
}
}
return zoneName
}
func domainFromSuffix(domain, suffix string) bool {
if suffix == "" {
return false
}
return domain == suffix || strings.HasSuffix(domain, "."+suffix)
}
// peerInDistributionGroups reports whether any of the peer's groups
// matches the service's bearer-auth distribution_groups.
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {
@@ -1804,7 +1837,7 @@ func shouldCheckRulesForNativeSSH(supportsNative bool, rule *PolicyRule, peer *n
// peerSupportedFirewallFeatures checks if the peer version supports port ranges.
func peerSupportedFirewallFeatures(peerVer string) supportedFeatures {
if strings.Contains(peerVer, "dev") {
if version.IsDevelopmentVersion(peerVer) {
return supportedFeatures{true, true}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -234,6 +235,113 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
}
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
account := privateZoneTestAccount(t)
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
// is the cluster serving it, and account.Domains holds the registered
// custom domain. The synth zone apex must be the registered domain,
// not the cluster, or the client's match-only zone never intercepts
// the query.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
zone := zones[0]
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
rec := zone.Records[0]
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
}
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "custom",
Domain: "app.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services[0].Domain = "a.example.com"
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "bapp",
Domain: "b.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
}
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// Service domain is outside the cluster and no account.Domains entry
// covers it: there is no apex that would intercept the query, so the
// service must be skipped rather than emit an unmatchable record.
account.Services[0].Domain = "app.example.com"
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
}
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// The registered custom domain matches the service FQDN by suffix but
// targets a different cluster than the service's ProxyCluster. It must
// be ignored, leaving no apex to intercept the query — otherwise the
// zone would point at this cluster's proxy peers under a domain owned
// by a different cluster.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services = append(account.Services, &service.Service{
@@ -254,3 +362,72 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
}
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
privateService := func(id, domain string) *service.Service {
return &service.Service{
ID: id,
AccountID: "acct-1",
Name: id,
Domain: domain,
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
}
}
publicService := func(id, domain string) *service.Service {
s := privateService(id, domain)
s.Private = false
return s
}
account.Services = []*service.Service{
// 3 private services under the cluster suffix.
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
// 4 private services under the custom domain suffix.
privateService("custom-1", "custom1.example.com"),
privateService("custom-2", "custom2.example.com"),
privateService("custom-3", "custom3.example.com"),
privateService("custom-4", "custom4.example.com"),
// 2 public services, one per suffix, must not surface.
publicService("public-cluster", "public.eu.proxy.netbird.io"),
publicService("public-custom", "public.example.com"),
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
clusterNames := recordNames(cluster)
assert.ElementsMatch(t,
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
clusterNames,
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
customNames := recordNames(custom)
assert.ElementsMatch(t,
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
customNames,
"only the 4 private custom services surface in the custom zone (public one excluded)")
}
// recordNames returns the record names of a zone for order-independent assertions.
func recordNames(zone nbdns.CustomZone) []string {
names := make([]string, 0, len(zone.Records))
for _, r := range zone.Records {
names = append(names, r.Name)
}
return names
}

View File

@@ -646,41 +646,7 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
expectedPorts: []string{"20-25", "10-100", "22022"},
},
{
name: "dev suffix version supports all features",
peer: &nbpeer.Peer{
ID: "peer1",
SSHEnabled: true,
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.50.0-dev",
Flags: nbpeer.Flags{ServerSSHAllowed: true},
},
},
rule: &PolicyRule{
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
},
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
expectedPorts: []string{"22", "22022"},
},
{
name: "dev suffix version supports all features",
peer: &nbpeer.Peer{
ID: "peer1",
SSHEnabled: true,
Meta: nbpeer.PeerSystemMeta{
WtVersion: "dev",
Flags: nbpeer.Flags{ServerSSHAllowed: true},
},
},
rule: &PolicyRule{
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
},
base: FirewallRule{PeerIP: "10.0.0.1", Direction: 0, Action: "accept", Protocol: "tcp"},
expectedPorts: []string{"22", "22022"},
},
{
name: "development suffix version supports all features",
name: "development version supports all features",
peer: &nbpeer.Peer{
ID: "peer1",
SSHEnabled: true,

View File

@@ -557,7 +557,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
@@ -628,9 +627,14 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
rules := []*RouteFirewallRule{&rule}
if includeIPv6 && r.IsDynamic() {
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
if isDefaultV4 {
ruleV6.Destination = "::/0"
ruleV6.RouteID = r.ID + "-v6-default"
}
rules = append(rules, &ruleV6)
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"testing"
"time"
@@ -1029,6 +1030,48 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
}
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
// (::/0 source and destination) for peers that support IPv6, mirroring the route
// the client installs. Without it, IPv6 traffic is routed to the exit node but
// dropped at the forward chain.
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
account, validatedPeers := scalableTestAccount(20, 2)
routingPeerID := "peer-5"
routingPeer := account.Peers[routingPeerID]
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
account.Routes["route-exit"] = &route.Route{
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
PeerID: routingPeerID, Peer: routingPeer.Key,
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
AccessControlGroups: []string{},
AccountID: "test-account",
}
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
require.NotNil(t, nm)
hasV4 := false
hasV6 := false
for _, rfr := range nm.RoutesFirewallRules {
switch rfr.Destination {
case "0.0.0.0/0":
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
hasV4 = true
}
case "::/0":
if slices.Contains(rfr.SourceRanges, "::/0") {
hasV6 = true
}
}
}
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
}
// ──────────────────────────────────────────────────────────────────────────────
// 15. MULTIPLE ROUTERS PER NETWORK
// ──────────────────────────────────────────────────────────────────────────────

View File

@@ -214,7 +214,10 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid --trusted-proxies: %w", err)
}
srv := proxy.New(proxy.Config{
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()
srv := proxy.New(ctx, proxy.Config{
ListenAddr: addr,
Logger: logger,
Version: Version,
@@ -246,14 +249,12 @@ func runServer(cmd *cobra.Command, args []string) error {
Private: private,
MaxDialTimeout: maxDialTimeout,
MaxSessionIdleTimeout: maxSessionIdleTimeout,
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
GeoDataDir: geoDataDir,
CrowdSecAPIURL: crowdsecAPIURL,
CrowdSecAPIKey: crowdsecAPIKey,
})
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()
return srv.ListenAndServe(ctx, addr)
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
stdlog "log"
"net"
"net/http"
@@ -42,7 +43,7 @@ const privateInboundPortHTTPS = 443
const privateInboundPortHTTP = 80
// inboundManager wires per-account inbound listeners into the proxy
// pipeline when --private-inbound is enabled. When disabled the manager
// pipeline when --private is enabled. When disabled the manager
// is nil and every method on *Server that touches it short-circuits.
type inboundManager struct {
logger *log.Logger
@@ -55,15 +56,18 @@ type inboundManager struct {
}
// inboundEntry owns the listeners, router and HTTP servers for a single
// account's embedded netstack.
// account's embedded netstack. errorLogWriters retain the logrus pipe
// writers backing each http.Server's ErrorLog so tearDown can close
// them — otherwise the pipe + its scanner goroutine leak per account.
type inboundEntry struct {
router *nbtcp.Router
tlsListener net.Listener
plainListener net.Listener
httpsServer *http.Server
httpServer *http.Server
cancel context.CancelFunc
wg sync.WaitGroup
router *nbtcp.Router
tlsListener net.Listener
plainListener net.Listener
httpsServer *http.Server
httpServer *http.Server
errorLogWriters []*io.PipeWriter
cancel context.CancelFunc
wg sync.WaitGroup
}
// pendingInboundRoute holds a route that arrived before the account's
@@ -147,30 +151,34 @@ func (m *inboundManager) bringUp(ctx context.Context, accountID types.AccountID,
return types.WithOverlayOrigin(ctx)
}
httpsErrLog, httpsErrW := newInboundErrorLog(m.logger, "https", accountID)
httpErrLog, httpErrW := newInboundErrorLog(m.logger, "http", accountID)
httpsServer := &http.Server{
Handler: scopedHandler,
TLSConfig: m.tlsConfig,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "https", accountID),
ErrorLog: httpsErrLog,
ConnContext: markOverlayOrigin,
}
httpServer := &http.Server{
Handler: scopedHandler,
ReadHeaderTimeout: httpInboundReadHeaderTimeout,
IdleTimeout: httpInboundIdleTimeout,
ErrorLog: newInboundErrorLog(m.logger, "http", accountID),
ErrorLog: httpErrLog,
ConnContext: markOverlayOrigin,
}
runCtx, cancel := context.WithCancel(ctx)
entry := &inboundEntry{
router: router,
tlsListener: tlsListener,
plainListener: plainListener,
httpsServer: httpsServer,
httpServer: httpServer,
cancel: cancel,
router: router,
tlsListener: tlsListener,
plainListener: plainListener,
httpsServer: httpsServer,
httpServer: httpServer,
errorLogWriters: []*io.PipeWriter{httpsErrW, httpErrW},
cancel: cancel,
}
entry.wg.Add(1)
@@ -237,6 +245,14 @@ func (m *inboundManager) tearDown(accountID types.AccountID, entry *inboundEntry
m.logger.Debugf("close per-account plain listener: %v", err)
}
entry.wg.Wait()
// Close the ErrorLog pipes only after the http.Servers have fully
// stopped so any straggling stdlib write doesn't race with the
// close. Each writer also tears down the logrus scanner goroutine.
for _, w := range entry.errorLogWriters {
if err := w.Close(); err != nil {
m.logger.Debugf("close per-account inbound error log writer: %v", err)
}
}
}
// AddRoute records an SNI/host route on the account's per-account router.
@@ -374,7 +390,7 @@ func (m *inboundManager) ListenerInfo(accountID types.AccountID) (InboundListene
}
// Snapshot returns the inbound listener state for every account that has
// a live listener at call time. Empty when --private-inbound is off or
// a live listener at call time. Empty when --private is off or
// no accounts have come up yet.
func (m *inboundManager) Snapshot() map[types.AccountID]InboundListenerInfo {
if m == nil {
@@ -497,7 +513,7 @@ func accountTunnelLookup(client *embed.Client) auth.TunnelLookupFunc {
// peerstore lookup to every request's context before delegating to next.
// Calling on the host-level listener is a no-op because that path never
// installs this wrapper, so the existing behaviour stays byte-for-byte
// identical when --private-inbound is off or the request didn't arrive
// identical when --private is off or the request didn't arrive
// on a per-account listener.
func withTunnelLookup(next http.Handler, lookup auth.TunnelLookupFunc) http.Handler {
if lookup == nil {
@@ -538,10 +554,14 @@ func (a inboundDebugAdapter) InboundListeners() map[types.AccountID]debug.Inboun
}
// newInboundErrorLog routes a per-account http.Server's stdlib error
// stream through logrus at warn level.
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) *stdlog.Logger {
return stdlog.New(logger.WithFields(log.Fields{
// stream through logrus at warn level. The returned PipeWriter must be
// closed by the caller (tearDown) once the http.Server has shut down —
// otherwise the pipe and its scanner goroutine leak per account, see
// logrus.Entry.WriterLevel.
func newInboundErrorLog(logger *log.Logger, scheme string, accountID types.AccountID) (*stdlog.Logger, *io.PipeWriter) {
w := logger.WithFields(log.Fields{
"inbound-http": scheme,
"account_id": accountID,
}).WriterLevel(log.WarnLevel), "", 0)
}).WriterLevel(log.WarnLevel)
return stdlog.New(w, "", 0), w
}

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
@@ -110,7 +111,7 @@ func TestServer_PrivateInbound_Enabled_WiresLifecycle(t *testing.T) {
// Construct a NetBird transport. We can't actually start the embedded
// client here (that needs a real management server), but we can
// confirm that the lifecycle callbacks are registered.
s.netbird = roundtrip.NewNetBird("test", "test", roundtrip.ClientConfig{
s.netbird = roundtrip.NewNetBird(t.Context(), "test", "test", roundtrip.ClientConfig{
MgmtAddr: "http://invalid.test",
}, quietLogger(), nil, fakeMgmtClient{})
@@ -139,7 +140,7 @@ func TestInboundManager_AddRouteAfterReady_RegistersDirectly(t *testing.T) {
// TestPrivateCapability_DerivedFromPrivateOnly tests that the capability
// bit reported upstream tracks --private exclusively. The previous
// --private-inbound flag has been folded into --private.
// --private flag has been folded into --private.
func TestPrivateCapability_DerivedFromPrivateOnly(t *testing.T) {
tests := []struct {
name string
@@ -318,7 +319,7 @@ func TestInboundManager_ListenerInfo(t *testing.T) {
}
// TestInboundManager_NilManagerSafe ensures the observability accessors
// are safe to call when --private-inbound is off (nil manager).
// are safe to call when --private is off (nil manager).
func TestInboundManager_NilManagerSafe(t *testing.T) {
var mgr *inboundManager
_, ok := mgr.ListenerInfo("anything")
@@ -482,6 +483,38 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config {
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12} //nolint:gosec
}
// TestNewInboundErrorLog_WriterIsCloseable guards the close path on the
// logrus PipeWriter that backs each per-account http.Server's ErrorLog.
// logrus.Entry.WriterLevel returns an *io.PipeWriter that owns a pipe +
// scanner goroutine; the caller must Close() it on teardown or the
// resources leak per account. The contract is verified two ways:
//
// - the constructor returns a non-nil writer the caller can keep,
// - writing to the writer after Close() fails with io.ErrClosedPipe,
// which is the only externally observable sign that Close was wired.
//
// A leaking refactor (forgetting to thread the writer to tearDown, or
// dropping the Close call) would still pass this test individually but
// fail an integration goleak check; this unit test is the cheap first
// line of defence.
func TestNewInboundErrorLog_WriterIsCloseable(t *testing.T) {
logger := quietLogger()
stdLog, writer := newInboundErrorLog(logger, "https", types.AccountID("acct-1"))
require.NotNil(t, stdLog, "newInboundErrorLog must return a non-nil *log.Logger")
require.NotNil(t, writer, "newInboundErrorLog must return the underlying PipeWriter so tearDown can Close it")
// First Close succeeds.
require.NoError(t, writer.Close(), "PipeWriter.Close should succeed the first time")
// After Close, the writer must refuse new writes — that's the only
// behavioural signal that the pipe (and its scanner goroutine) has
// shut down.
_, err := writer.Write([]byte("post-close write\n"))
require.ErrorIs(t, err, io.ErrClosedPipe,
"writes after Close must surface io.ErrClosedPipe so callers know the writer is gone")
}
// testCertPEM / testKeyPEM are a minimal RSA self-signed cert for
// 127.0.0.1 — only used by tests that need a working TLS handshake.
var testCertPEM = []byte(`-----BEGIN CERTIFICATE-----

View File

@@ -346,13 +346,15 @@ func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Re
// management unreachable, peer unknown, user not in group) returns false so
// the caller falls back to the existing OIDC scheme dispatch.
//
// Phase 3 adds a local-first short-circuit: when the request arrived on a
// per-account inbound listener the context carries a peerstore lookup
// (TunnelLookupFromContext). If the lookup says the IP isn't in the account's
// roster the proxy denies fast without calling management. If the lookup
// confirms a known peer the RPC still runs for the user-identity tail
// (UserID + group access), but its result is cached for tunnelCacheTTL so
// repeat requests skip management entirely.
// The fast-path is gated on TunnelLookupFromContext(r.Context()) being
// present — that context value is attached only by the per-account
// inbound (overlay) listener. The host listener never sets it, so a
// public client whose source IP happens to fall inside an RFC1918 / ULA
// / CGNAT range can't impersonate a mesh peer by colliding with a
// tunnel-IP. Once we know the request arrived over WireGuard the
// per-account peerstore lookup is consulted: a miss denies fast (no
// management round-trip), a hit gates the cached ValidateTunnelPeer RPC
// that mints the session JWT.
func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
if mw.sessionValidator == nil {
return false
@@ -361,18 +363,24 @@ func (mw *Middleware) forwardWithTunnelPeer(w http.ResponseWriter, r *http.Reque
if !clientIP.IsValid() {
return false
}
// Anti-spoof: only honour the tunnel-peer fast-path on requests that
// were stamped by an overlay listener. Without that marker an
// attacker could send a request from a colliding RFC1918 / CGNAT
// source on the public listener and bypass operator auth.
lookup := TunnelLookupFromContext(r.Context())
if lookup == nil {
return false
}
if !isTunnelSourceIP(clientIP) {
return false
}
if lookup := TunnelLookupFromContext(r.Context()); lookup != nil {
if _, ok := lookup(clientIP); !ok {
mw.logger.WithFields(log.Fields{
"host": host,
"remote": clientIP,
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
return false
}
if _, ok := lookup(clientIP); !ok {
mw.logger.WithFields(log.Fields{
"host": host,
"remote": clientIP,
}).Debug("local peerstore: tunnel IP not in account roster; denying without RPC")
return false
}
resp, _, err := mw.tunnelCache.fetch(r.Context(), tunnelCacheKey{

View File

@@ -1227,3 +1227,93 @@ func TestProtect_NonOIDCSchemes_PlainHTTP_NotBlocked(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, rec.Code, "PIN-only domain should serve the login page on plain HTTP")
}
// stubTunnelValidator records ValidateTunnelPeer calls so a test can
// assert whether the fast-path reached management.
type stubTunnelValidator struct {
called bool
resp *proto.ValidateTunnelPeerResponse
}
func (s *stubTunnelValidator) ValidateSession(context.Context, *proto.ValidateSessionRequest, ...grpc.CallOption) (*proto.ValidateSessionResponse, error) {
return nil, errors.New("not used in this test")
}
func (s *stubTunnelValidator) ValidateTunnelPeer(context.Context, *proto.ValidateTunnelPeerRequest, ...grpc.CallOption) (*proto.ValidateTunnelPeerResponse, error) {
s.called = true
return s.resp, nil
}
// TestProtect_TunnelPeerFastPath_RequiresInboundMarker guards the
// anti-spoof gate: a request with an RFC1918 source IP arriving on the
// public listener (no TunnelLookupFromContext attached) must not be
// allowed to take the tunnel-peer fast-path. Without this gate a public
// client whose source IP happens to fall inside an RFC1918 range could
// bypass the configured auth scheme by colliding with a known tunnel
// IP.
func TestProtect_TunnelPeerFastPath_RequiresInboundMarker(t *testing.T) {
validator := &stubTunnelValidator{
resp: &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "should-not-be-used",
UserId: "user-1",
},
}
mw := NewMiddleware(log.StandardLogger(), validator, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
// Request from an RFC1918 source IP on the public listener — no
// TunnelLookupFromContext attached. The fast-path must reject this
// and fall through to the PIN scheme (which renders 401 on plain
// HTTP for a non-authenticated request).
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "100.64.0.5:5000"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.False(t, validator.called,
"ValidateTunnelPeer must not be invoked when the request lacks the inbound TunnelLookup marker")
assert.Equal(t, http.StatusUnauthorized, rec.Code,
"without the inbound marker the request must fall through to the operator auth scheme")
}
// TestProtect_TunnelPeerFastPath_TakesPathWithInboundMarker verifies
// the positive side: a request marked as overlay-origin (carrying the
// TunnelLookup context value) and matching a tunnel-IP range does take
// the fast-path and reach management.
func TestProtect_TunnelPeerFastPath_TakesPathWithInboundMarker(t *testing.T) {
validator := &stubTunnelValidator{
resp: &proto.ValidateTunnelPeerResponse{
Valid: true,
SessionToken: "tunnel-session-token",
UserId: "user-1",
},
}
mw := NewMiddleware(log.StandardLogger(), validator, nil)
kp := generateTestKeyPair(t)
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "", nil, false))
handler := mw.Protect(newPassthroughHandler())
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
req.RemoteAddr = "100.64.0.5:5000"
req = req.WithContext(WithTunnelLookup(req.Context(), lookup))
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.True(t, validator.called,
"ValidateTunnelPeer must run when the request carries the inbound TunnelLookup marker")
assert.Equal(t, http.StatusOK, rec.Code,
"a successful tunnel-peer validation must forward to the next handler")
}

View File

@@ -101,7 +101,10 @@ func TestForwardWithTunnelPeer_GroupsPropagateToCapturedData(t *testing.T) {
w, r := newTunnelRequest("100.64.0.10:55555")
cd := proxy.NewCapturedData("")
r = r.WithContext(proxy.WithCapturedData(r.Context(), cd))
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
r = r.WithContext(proxy.WithCapturedData(WithTunnelLookup(r.Context(), lookup), cd))
called := false
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { called = true })
@@ -148,9 +151,13 @@ func TestForwardWithTunnelPeer_LocalLookupKnownPeerStillRPCs(t *testing.T) {
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "RPC must run for the user-identity tail when local lookup confirms the peer")
}
// TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath ensures the existing
// behaviour stays intact on the host-level listener (no lookup attached).
func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
// TestForwardWithTunnelPeer_NoLookupRefusesFastPath guards the
// anti-spoof gate: requests that didn't arrive on the per-account
// inbound listener (no TunnelLookup attached) must never reach
// management's ValidateTunnelPeer, even when the source IP looks like
// a tunnel address. A colliding RFC1918 / CGNAT source on the public
// listener would otherwise impersonate a mesh peer.
func TestForwardWithTunnelPeer_NoLookupRefusesFastPath(t *testing.T) {
validator := &stubSessionValidator{
respFn: func(_ *proto.ValidateTunnelPeerRequest) *proto.ValidateTunnelPeerResponse {
return &proto.ValidateTunnelPeerResponse{Valid: true, SessionToken: "tok", UserId: "user-1"}
@@ -165,9 +172,9 @@ func TestForwardWithTunnelPeer_NoLookupKeepsLegacyPath(t *testing.T) {
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
assert.True(t, handled, "host-level path forwards on positive RPC result")
assert.True(t, called, "next handler runs on host-level success")
assert.Equal(t, int32(1), validator.tunnelCalls.Load(), "host-level path always RPCs (Phase 3 unchanged)")
assert.False(t, handled, "fast-path must refuse without the inbound marker")
assert.False(t, called, "next handler must not run")
assert.Equal(t, int32(0), validator.tunnelCalls.Load(), "ValidateTunnelPeer must not be invoked without the inbound marker")
}
// TestForwardWithTunnelPeer_RPCErrorFallsThrough validates that an RPC
@@ -201,8 +208,13 @@ func TestForwardWithTunnelPeer_CacheReusesPositiveResponse(t *testing.T) {
}
mw := newTunnelMiddleware(t, validator)
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
for i := 0; i < 4; i++ {
w, r := newTunnelRequest("100.64.0.10:55555")
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
config, _ := mw.getDomainConfig("svc.example")
handled := mw.forwardWithTunnelPeer(w, r, "svc.example", config, next)
@@ -226,11 +238,21 @@ func TestForwardWithTunnelPeer_RoutesAccountIDIntoCacheKey(t *testing.T) {
require.NoError(t, mw.AddDomain("svc-a.example", nil, "", 0, "acct-a", "svc-a", nil, false))
require.NoError(t, mw.AddDomain("svc-b.example", nil, "", 0, "acct-b", "svc-b", nil, false))
// The fast-path requires the inbound-listener marker on the context.
// The peerstore lookup itself is account-agnostic at this level
// (one TunnelLookupFunc per account is attached by inbound.go); a
// trivial "always hit" lookup is enough to exercise the cache-key
// branch this test covers.
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
for _, host := range []string{"svc-a.example", "svc-b.example"} {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://"+host+"/", nil)
r.Host = host
r.RemoteAddr = "100.64.0.10:55555"
r = r.WithContext(WithTunnelLookup(r.Context(), lookup))
config, _ := mw.getDomainConfig(host)
handled := mw.forwardWithTunnelPeer(w, r, host, config, http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
require.True(t, handled, "host %s should forward", host)
@@ -314,9 +336,17 @@ func TestPrivateService_ForwardsOnTunnelPeerSuccess(t *testing.T) {
w.WriteHeader(http.StatusOK)
}))
// Per-account inbound listener attaches WithTunnelLookup; without it
// forwardWithTunnelPeer refuses to take the fast-path. Mirror the
// real flow so this test exercises the post-gating success branch.
lookup := TunnelLookupFunc(func(_ netip.Addr) (PeerIdentity, bool) {
return PeerIdentity{}, true
})
req := httptest.NewRequest(http.MethodGet, "https://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.10:55555"
req = req.WithContext(WithTunnelLookup(req.Context(), lookup))
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

View File

@@ -131,7 +131,7 @@ func (h *Handler) SetCertStatus(cs certStatus) {
// SetInboundProvider wires per-account inbound listener observability.
// Pass nil (or skip the call) to keep the inbound section out of debug
// responses on proxies that don't run --private-inbound.
// responses on proxies that don't run --private.
func (h *Handler) SetInboundProvider(p InboundProvider) {
h.inbound = p
}

View File

@@ -66,6 +66,22 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Loop guard for private services: a peer that hosts the target
// dialing its own service URL would round-trip its own traffic
// through the proxy and back over WG to itself. Refuse the request
// with 421 (Misdirected Request) so the caller sees an explicit
// error instead of silently doubling tunnel traffic.
if p.isSelfTargetLoop(r, result.target.URL) {
if cd := CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(OriginNoRoute)
}
requestID := getRequestID(r)
web.ServeErrorPage(w, r, http.StatusMisdirectedRequest, "Loop Detected",
"This peer is the target of the requested service. Reach the backend directly instead of dialing the public service URL from the same machine.",
requestID, web.ErrorStatus{Proxy: true, Destination: false})
return
}
ctx := r.Context()
// Set the account ID in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, result.accountID)
@@ -107,6 +123,32 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rp.ServeHTTP(w, r.WithContext(ctx))
}
// isSelfTargetLoop reports whether an overlay-origin request is about to
// be forwarded back to the very peer that initiated it. The detection
// is intentionally narrow: it only fires when the request arrived on
// the per-account inbound (overlay) listener (so we're confident the
// source address is the caller's tunnel IP), and only when the resolved
// target host matches that tunnel IP. Catching this here returns 421 to
// the caller instead of letting the proxy round-trip its own traffic
// over WG twice.
func (p *ReverseProxy) isSelfTargetLoop(r *http.Request, target *url.URL) bool {
if target == nil {
return false
}
if !types.IsOverlayOrigin(r.Context()) {
return false
}
srcIP := extractHostIP(r.RemoteAddr)
if !srcIP.IsValid() {
return false
}
targetIP, err := netip.ParseAddr(target.Hostname())
if err != nil {
return false
}
return srcIP.Unmap() == targetIP.Unmap()
}
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
// inbound requests to target the backend service while setting security-relevant
// forwarding headers and stripping proxy authentication credentials.

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
)
@@ -1285,6 +1286,103 @@ func TestStampNetBirdIdentity_OmitsGroupsHeaderWhenAllInvalid(t *testing.T) {
"X-NetBird-Groups must not be set when every group label is rejected")
}
// nopOKTransport returns 200 for every request without dialing — used
// by the self-target-loop tests so the non-loop cases don't pay a real
// TCP-dial timeout.
type nopOKTransport struct{}
func (nopOKTransport) RoundTrip(*http.Request) (*http.Response, error) {
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody, Header: http.Header{}}, nil
}
// TestServeHTTP_SelfTargetLoopReturns421 covers the loop guard for
// private services: when a peer dials a service whose only target is
// the peer itself, the proxy must refuse with 421 (Misdirected
// Request) rather than round-tripping the request back over WG to
// the same peer.
func TestServeHTTP_SelfTargetLoopReturns421(t *testing.T) {
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
rp.AddMapping(Mapping{
ID: "svc-1",
AccountID: "acct-1",
Host: "private.svc",
Paths: map[string]*PathTarget{
"/": {
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
},
},
})
req := httptest.NewRequest(http.MethodGet, "http://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.5:55555"
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
rec := httptest.NewRecorder()
rp.ServeHTTP(rec, req)
assert.Equal(t, http.StatusMisdirectedRequest, rec.Code,
"a peer dialing a service whose target is itself must get 421")
}
// TestServeHTTP_SelfTargetLoop_NonOverlayRequestPassesThrough verifies
// the guard is scoped to overlay-origin requests. A public-listener
// request that happens to share a source IP with the target host must
// not be misinterpreted as a loop — the gating relies on the inbound
// marker being attached only by the per-account overlay listener.
func TestServeHTTP_SelfTargetLoop_NonOverlayRequestPassesThrough(t *testing.T) {
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
rp.AddMapping(Mapping{
ID: "svc-1",
AccountID: "acct-1",
Host: "public.svc",
Paths: map[string]*PathTarget{
"/": {
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
},
},
})
req := httptest.NewRequest(http.MethodGet, "http://public.svc/", nil)
req.Host = "public.svc"
req.RemoteAddr = "100.64.0.5:55555"
// No WithOverlayOrigin → the guard must not fire.
rec := httptest.NewRecorder()
rp.ServeHTTP(rec, req)
assert.NotEqual(t, http.StatusMisdirectedRequest, rec.Code,
"a non-overlay request with a colliding source IP must not be flagged as a loop")
}
// TestServeHTTP_SelfTargetLoop_OverlayDifferentIPPassesThrough confirms
// that overlay-origin requests with a source IP that does *not* match
// the target host are forwarded normally.
func TestServeHTTP_SelfTargetLoop_OverlayDifferentIPPassesThrough(t *testing.T) {
rp := NewReverseProxy(nopOKTransport{}, "auto", nil, nil)
rp.AddMapping(Mapping{
ID: "svc-1",
AccountID: "acct-1",
Host: "private.svc",
Paths: map[string]*PathTarget{
"/": {
URL: &url.URL{Scheme: "http", Host: "100.64.0.5:8080"},
},
},
})
req := httptest.NewRequest(http.MethodGet, "http://private.svc/", nil)
req.Host = "private.svc"
req.RemoteAddr = "100.64.0.99:55555" // different from the target
req = req.WithContext(types.WithOverlayOrigin(req.Context()))
rec := httptest.NewRecorder()
rp.ServeHTTP(rec, req)
assert.NotEqual(t, http.StatusMisdirectedRequest, rec.Code,
"overlay request with a non-matching source IP must not be flagged as a loop")
}
// TestStampNetBirdIdentity_CapturedDataPresentButEmpty covers requests
// that carry CapturedData with no identity fields populated (e.g. the
// auth middleware ran but the request didn't authenticate). Both

View File

@@ -28,6 +28,10 @@ import (
const deviceNamePrefix = "ingress-proxy-"
const clientStopTimeout = 30 * time.Second
const createProxyPeerTimeout = 30 * time.Second
// backendKey identifies a backend by its host:port from the target URL.
type backendKey string
@@ -152,6 +156,7 @@ type managementClient interface {
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple services to share the same connection.
type NetBird struct {
ctx context.Context
proxyID string
proxyAddr string
clientCfg ClientConfig
@@ -161,6 +166,7 @@ type NetBird struct {
clientsMux sync.RWMutex
clients map[types.AccountID]*clientEntry
lifecycleMu sync.Map
initLogOnce sync.Once
statusNotifier statusNotifier
// readyHandler runs after the embedded client for an account reports
@@ -176,6 +182,10 @@ type NetBird struct {
// (i.e. when a new client was actually created, not when an existing one
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
OnAddPeer func(d time.Duration, err error)
// startClient runs the post-create client startup. Nil uses runClientStartup;
// tests override it to avoid a real embed client.Start.
startClient func(accountID types.AccountID, client *embed.Client)
}
// ClientDebugInfo contains debug information about a client.
@@ -199,27 +209,20 @@ type skipTLSVerifyContextKey struct{}
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
si := serviceInfo{serviceID: serviceID}
n.clientsMux.Lock()
if n.registerExistingClient(accountID, key, si) {
return nil
}
entry, exists := n.clients[accountID]
if exists {
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
}
}()
if n.registerExistingClient(accountID, key, si) {
return nil
}
@@ -229,10 +232,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
n.OnAddPeer(time.Since(createStart), err)
}
if err != nil {
n.clientsMux.Unlock()
return err
}
n.clientsMux.Lock()
n.clients[accountID] = entry
n.clientsMux.Unlock()
@@ -241,15 +244,64 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
"service_key": key,
}).Info("created new client for account")
// Attempt to start the client in the background; if this fails we will
// retry on the first request via RoundTrip.
go n.runClientStartup(ctx, accountID, entry.client)
transferred = true
go func() {
defer lifecycle.Unlock()
n.startClientStartup(accountID, entry.client)
}()
return nil
}
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
if n.startClient != nil {
n.startClient(accountID, client)
return
}
n.runClientStartup(accountID, client)
}
// registerExistingClient registers the service against an already-present
// client for the account and returns true when it did. It notifies management
// of the new service when the client is already started.
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if !exists {
n.clientsMux.Unlock()
return false
}
entry.services[key] = si
started := entry.started
n.clientsMux.Unlock()
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).Debug("registered service with existing client")
if started && n.statusNotifier != nil {
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": key,
}).WithError(err).Warn("failed to notify status for existing client")
}
}
return true
}
// accountLifecycle returns the per-account lifecycle mutex, serialising client
// creation against teardown so a slow client.Stop cannot race a new
// client.Start for the same account, without blocking clientsMux.
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
return mu.(*sync.Mutex)
}
// createClientEntry generates a WireGuard keypair, authenticates with management,
// and creates an embedded NetBird client. Must be called with clientsMux held.
// and creates an embedded NetBird client. Must be called with the account's
// lifecycle mutex held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
serviceID := si.serviceID
n.logger.WithFields(log.Fields{
@@ -269,7 +321,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
"public_key": publicKey.String(),
}).Debug("authenticating new proxy peer with management")
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
defer cancel()
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
ServiceId: string(serviceID),
AccountId: string(accountID),
Token: authToken,
@@ -307,7 +361,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: n.clientCfg.BlockInbound,
BlockInbound: n.clientCfg.BlockInbound,
// The embedded proxy peer must never be a stepping stone into
// the proxy host's LAN: it only exists to reach NetBird mesh
// targets or, when direct_upstream is set, the host network
@@ -355,8 +409,14 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
}, nil
}
// runClientStartup starts the client and notifies registered services on success.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
// runClientStartup starts the client and notifies registered services on
// success. This function runs in a goroutine launched from AddPeer, so it
// must never inherit the caller's request-scoped context — a canceled
// request must not abort the inbound listener bring-up or the management
// status notification. The embedded client.Start gets its own bounded
// startCtx; once Start succeeds, notifyClientReady takes over with a
// fresh context.Background() (see that function for the contract).
func (n *NetBird) runClientStartup(accountID types.AccountID, client *embed.Client) {
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -369,7 +429,17 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
return
}
// Mark client as started and collect services to notify outside the lock.
n.notifyClientReady(accountID, client)
}
// notifyClientReady marks the account's client as started, fires the
// readyHandler hook, and notifies management of the new tunnel
// connection for every registered service. It is split out of
// runClientStartup so a regression test can drive the post-Start tail
// without needing a live embedded client. The contract that the
// hooks/notifier see context.Background() — never the AddPeer caller's
// ctx — lives here.
func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Client) {
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
if exists {
@@ -385,7 +455,7 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
n.clientsMux.Unlock()
if readyHandler != nil {
state := readyHandler(ctx, accountID, client)
state := readyHandler(n.ctx, accountID, client)
n.clientsMux.Lock()
if e, ok := n.clients[accountID]; ok {
e.inbound = state
@@ -404,7 +474,7 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
return
}
for _, sn := range toNotify {
if err := n.statusNotifier.NotifyStatus(ctx, accountID, sn.serviceID, true); err != nil {
if err := n.statusNotifier.NotifyStatus(n.ctx, accountID, sn.serviceID, true); err != nil {
n.logger.WithFields(log.Fields{
"account_id": accountID,
"service_key": sn.key,
@@ -421,6 +491,15 @@ func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountI
// RemovePeer unregisters a service from an account. The client is only stopped
// when no services are using it anymore.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
lifecycle := n.accountLifecycle(accountID)
lifecycle.Lock()
transferred := false
defer func() {
if !transferred {
lifecycle.Unlock()
}
}()
n.clientsMux.Lock()
entry, exists := n.clients[accountID]
@@ -443,17 +522,8 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
delete(entry.services, key)
stopClient := len(entry.services) == 0
var client *embed.Client
var transport, insecureTransport *http.Transport
var inbound any
var stopHandler func(types.AccountID, any)
if stopClient {
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
client = entry.client
transport = entry.transport
insecureTransport = entry.insecureTransport
inbound = entry.inbound
stopHandler = n.stopHandler
delete(n.clients, accountID)
} else {
n.logger.WithFields(log.Fields{
@@ -467,19 +537,40 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
if stopClient {
if inbound != nil && stopHandler != nil {
stopHandler(accountID, inbound)
}
transport.CloseIdleConnections()
insecureTransport.CloseIdleConnections()
if err := client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
transferred = true
go n.stopClientLocked(accountID, lifecycle, entry)
}
return nil
}
// stopClientLocked releases a client's resources off the caller's goroutine so a
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
// synchronously). It unlocks lifecycle when done so a new client.Start for the
// same account waits for this teardown.
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
defer lifecycle.Unlock()
if entry.inbound != nil && n.stopHandler != nil {
n.stopHandler(accountID, entry.inbound)
}
if entry.transport != nil {
entry.transport.CloseIdleConnections()
}
if entry.insecureTransport != nil {
entry.insecureTransport.CloseIdleConnections()
}
if entry.client == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
defer cancel()
if err := entry.client.Stop(ctx); err != nil {
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
}
}
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
if n.statusNotifier == nil {
return
@@ -666,11 +757,12 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random
// OS-assigned port. A fixed port only works with single-account deployments;
// multiple accounts will fail to bind the same port.
func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
func NewNetBird(ctx context.Context, proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{
ctx: ctx,
proxyID: proxyID,
proxyAddr: proxyAddr,
clientCfg: clientCfg,

View File

@@ -6,11 +6,13 @@ import (
"net/netip"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -21,6 +23,18 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
// tests can detect AddPeer reaching client creation.
type signalMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
m.once.Do(func() { close(m.entered) })
return &proto.CreateProxyPeerResponse{Success: true}, nil
}
type mockStatusNotifier struct {
mu sync.Mutex
statuses []statusCall
@@ -30,12 +44,15 @@ type statusCall struct {
accountID types.AccountID
serviceID types.ServiceID
connected bool
// ctx is captured so tests can assert the notifier received a
// fresh background context rather than an inherited request ctx.
ctx context.Context
}
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
func (m *mockStatusNotifier) NotifyStatus(ctx context.Context, accountID types.AccountID, serviceID types.ServiceID, connected bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected})
m.statuses = append(m.statuses, statusCall{accountID, serviceID, connected, ctx})
return nil
}
@@ -48,11 +65,15 @@ func (m *mockStatusNotifier) calls() []statusCall {
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
return NewNetBird("test-proxy", "invalid.test", ClientConfig{
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, nil, &mockMgmtClient{})
// Skip the real embed client.Start, which would hang against the unreachable
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
nb.startClient = func(types.AccountID, *embed.Client) {}
return nb
}
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
@@ -279,11 +300,12 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("account-1")
// Add first service — creates a new client entry.
@@ -295,8 +317,12 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second service — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
// Add second service with an already-cancelled caller context —
// should notify immediately (client is started) AND the notification
// must not inherit the cancelled ctx.
cancelledCtx, cancel := context.WithCancel(context.Background())
cancel()
err = nb.AddPeer(cancelledCtx, accountID, "domain2.test", "key-1", types.ServiceID("svc-2"))
require.NoError(t, err)
calls := notifier.calls()
@@ -304,6 +330,9 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
assert.Equal(t, accountID, calls[0].accountID)
assert.Equal(t, types.ServiceID("svc-2"), calls[0].serviceID)
assert.True(t, calls[0].connected)
require.NotNil(t, calls[0].ctx, "NotifyStatus must receive a context")
require.NoError(t, calls[0].ctx.Err(),
"already-started NotifyStatus must use a background ctx, not the cancelled caller ctx")
}
// TestNetBird_IdentityForIP_UnknownAccountReturnsFalse confirms that the
@@ -338,7 +367,7 @@ func TestClientEntry_IdentityForIP_InvalidIPReturnsFalse(t *testing.T) {
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
@@ -360,3 +389,164 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
assert.Equal(t, types.ServiceID("svc-1"), calls[0].serviceID)
assert.False(t, calls[0].connected)
}
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
// stall: RemovePeer must return promptly even when the client teardown blocks,
// because teardown runs off the caller's goroutine. The receive loop calls
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
accountID := types.AccountID("acct-async-teardown")
key := DomainServiceKey("svc.example")
teardownEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
close(teardownEntered)
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
done := make(chan error, 1)
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
select {
case err := <-done:
require.NoError(t, err)
case <-time.After(2 * time.Second):
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
}
select {
case <-teardownEntered:
case <-time.After(2 * time.Second):
t.Fatal("teardown never ran")
}
close(releaseTeardown)
}
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
// new client bringup behind an in-flight teardown for the same account, so a
// slow client.Stop can never race a new client.Start for that account.
//
// It targets the handoff race specifically: AddPeer is launched immediately
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
// This only passes if RemovePeer acquires the lifecycle lock synchronously
// (before returning) and hands it to the teardown goroutine — if the goroutine
// acquired the lock itself, AddPeer could win the lock in this window and start
// a replacement client while the old teardown is still pending.
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
nb.startClient = func(types.AccountID, *embed.Client) {}
accountID := types.AccountID("acct-serialize")
key := DomainServiceKey("svc.example")
addEntered := make(chan struct{})
releaseTeardown := make(chan struct{})
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
// Block teardown until released. If AddPeer ever reaches createClientEntry
// (signalled via the mgmt client below) while we hold the lock, the lock
// failed to serialise and the test fails before we release.
<-releaseTeardown
})
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
started: true,
inbound: struct{}{},
}
nb.clientsMux.Unlock()
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
// AddPeer got past the lifecycle lock and into client creation.
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
// Launch AddPeer with NO synchronisation against the teardown goroutine.
addReturned := make(chan struct{})
go func() {
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
close(addReturned)
}()
select {
case <-addEntered:
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
case <-addReturned:
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
case <-time.After(300 * time.Millisecond):
}
close(releaseTeardown)
select {
case <-addReturned:
case <-time.After(2 * time.Second):
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
}
}
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
// a fresh context.Background() rather than inheriting the AddPeer
// caller's request- or stream-scoped ctx. Without this, a cancelled
// caller ctx could abort the inbound listener bring-up or cause the
// management status notification to fail spuriously and leave the
// account in a half-connected state.
func TestNotifyClientReady_UsesBackgroundCtx(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("acct-async")
// Pre-populate a client entry so notifyClientReady has something
// to mark started + something to enumerate for NotifyStatus.
nb.clientsMux.Lock()
nb.clients[accountID] = &clientEntry{
services: map[ServiceKey]serviceInfo{
DomainServiceKey("svc.example"): {serviceID: types.ServiceID("svc-1")},
},
}
nb.clientsMux.Unlock()
var capturedReadyCtx context.Context
nb.SetClientLifecycle(
func(ctx context.Context, _ types.AccountID, _ *embed.Client) any {
capturedReadyCtx = ctx
return nil
},
nil,
)
// Drive the post-Start path directly; a real client.Start would
// need a working management URL.
nb.notifyClientReady(accountID, nil)
require.NotNil(t, capturedReadyCtx, "readyHandler must have been invoked")
require.NoError(t, capturedReadyCtx.Err(),
"readyHandler must receive a background context, not an inherited cancelled one")
deadline, ok := capturedReadyCtx.Deadline()
assert.False(t, ok, "readyHandler ctx must have no deadline (background); got %v", deadline)
calls := notifier.calls()
require.Len(t, calls, 1, "NotifyStatus must be invoked once per registered service")
require.NotNil(t, calls[0].ctx, "NotifyStatus must receive a context")
require.NoError(t, calls[0].ctx.Err(),
"NotifyStatus must receive a background context, not an inherited cancelled one")
}

View File

@@ -1781,11 +1781,14 @@ func TestRouter_PlainHTTP_RoutesToPlainChannel(t *testing.T) {
}
}()
tlsListener, ok := router.HTTPListener().(*chanListener)
require.True(t, ok, "router.HTTPListener() must be the test's chanListener; the test relies on observing its channel directly")
select {
case conn := <-acceptDone:
require.NotNil(t, conn)
_ = conn.Close()
case <-router.HTTPListener().(*chanListener).ch:
case <-tlsListener.ch:
t.Fatal("plain HTTP request leaked into TLS channel")
case <-time.After(3 * time.Second):
t.Fatal("plain HTTP connection never reached plain channel")

View File

@@ -1,6 +1,7 @@
package proxy
import (
"context"
"net/netip"
"time"
@@ -20,14 +21,17 @@ import (
type Config struct {
// ListenAddr is the TCP address the main listener binds. Required.
ListenAddr string
// ID identifies this proxy instance to management. Empty value lets
// New generate a timestamped default.
// ID identifies this proxy instance to management. Empty values are
// replaced with a timestamped default at Server.Start time (see
// initDefaults), not in New.
ID string
// Logger is the logrus logger used everywhere. Empty value falls back
// to log.StandardLogger().
// Logger is the logrus logger used everywhere. Empty values fall
// back to log.StandardLogger() at Server.Start time (see
// initDefaults), not in New.
Logger *log.Logger
// Version is the build version string reported to management. Empty
// becomes "dev".
// values are replaced with "dev" at Server.Start time (see
// initDefaults), not in New.
Version string
// ProxyURL is the public address operators use to reach this proxy.
ProxyURL string
@@ -110,6 +114,10 @@ type Config struct {
MaxDialTimeout time.Duration
// MaxSessionIdleTimeout caps the per-service session idle timeout.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// being applied before the receive loop reconnects to resync. Zero falls
// back to the internal default.
MappingBatchWatchdog time.Duration
// GeoDataDir is the directory containing GeoLite2 MMDB files.
GeoDataDir string
@@ -125,8 +133,9 @@ type Config struct {
// bound — call Start to bring the proxy up. Returning a fully-formed
// Server keeps the standalone code path (which still constructs Server
// directly) byte-for-byte equivalent.
func New(cfg Config) *Server {
func New(ctx context.Context, cfg Config) *Server {
return &Server{
ctx: ctx,
ListenAddr: cfg.ListenAddr,
ID: cfg.ID,
Logger: cfg.Logger,
@@ -159,6 +168,7 @@ func New(cfg Config) *Server {
Private: cfg.Private,
MaxDialTimeout: cfg.MaxDialTimeout,
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
GeoDataDir: cfg.GeoDataDir,
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
CrowdSecAPIKey: cfg.CrowdSecAPIKey,

282
proxy/mapping_stall_test.go Normal file
View File

@@ -0,0 +1,282 @@
package proxy
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// blockingMgmtClient implements roundtrip's managementClient interface.
// CreateProxyPeer parks until release is closed, signalling entry on entered.
// This reproduces the confirmed real-world stall: createClientEntry calls
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
// receive loop calls that path synchronously inside processMappings.
type blockingMgmtClient struct {
entered chan struct{}
once sync.Once
}
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
b.once.Do(func() { close(b.entered) })
// Park until the caller's context is cancelled. In production this ctx is
// the gRPC mapping-stream context with no per-call timeout, so a slow or
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
<-ctx.Done()
return nil, ctx.Err()
}
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
// pre-seeded list of messages, then records how many times Recv advanced. It
// lets the test observe whether the single-threaded receive loop ever gets
// past the first (blocking) batch to pull the second message.
type gatedMappingStream struct {
grpc.ClientStream
messages []*proto.GetMappingUpdateResponse
idx int32
}
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
i := int(atomic.LoadInt32(&g.idx))
if i >= len(g.messages) {
// Block instead of returning EOF so the loop doesn't exit; we only
// care whether the loop ever reaches this second Recv at all.
select {}
}
msg := g.messages[i]
atomic.AddInt32(&g.idx, 1)
return msg, nil
}
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
func (g *gatedMappingStream) CloseSend() error { return nil }
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
func (g *gatedMappingStream) SendMsg(any) error { return nil }
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
// noopNotifier satisfies roundtrip's statusNotifier interface.
type noopNotifier struct{}
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
return nil
}
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
// compile time; none of those methods are called by this test.
type noopProxyClient struct {
proto.ProxyServiceClient
}
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
return &proto.SendStatusUpdateResponse{}, nil
}
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
// mapping receive loop processes batches strictly serially, so when applying
// one batch blocks (here: createClientEntry parked on a synchronous
// CreateProxyPeer call, exactly as observed in production), the loop never
// advances to Recv the next batch. Management can keep sending updates onto
// the stream with no error and no channel overflow, yet the proxy applies
// nothing further — it is stuck.
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
mgmt := &blockingMgmtClient{
entered: make(chan struct{}),
}
nb := roundtrip.NewNetBird(
context.Background(),
"proxy-test",
"proxy.example.com",
roundtrip.ClientConfig{},
logger,
noopNotifier{},
mgmt,
)
s := &Server{
Logger: logger,
netbird: nb,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// First batch: a CREATED mapping for a brand-new account. addMapping ->
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
// no routers/auth need wiring. The second batch exists only to detect
// whether the loop ever advances past the blocked first batch.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-1",
AccountId: "acct-1",
AuthToken: "token-1",
},
},
},
{
Mapping: []*proto.ProxyMapping{
{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "svc-2",
AccountId: "acct-2",
AuthToken: "token-2",
},
},
},
},
}
ctx, cancel := context.WithCancel(context.Background())
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
// avoiding any dependency on collaborators this test deliberately leaves
// nil. The deadlock is fully proven before this fires.
t.Cleanup(cancel)
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
}()
// The loop must reach the blocking apply for the first batch.
select {
case <-mgmt.entered:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
}
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
// single-threaded loop cannot advance. The second batch is never pulled,
// even though it is already available on the stream. Give it ample time.
// deliveredCount is atomic; syncDone is intentionally not read here because
// the loop goroutine owns it (reading it from the test would race).
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged in apply")
default:
// Still wedged, as expected.
}
}
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
// path observed in production: a mapping remove tears down the account's last
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
// threaded, a blocked remove wedges the loop: no further mapping updates of any
// kind (create/modify/remove) are applied, while management keeps sending them
// successfully (no send error, no channel-full). Matches the reported symptom:
// the last log line is a remove that stops a client, then silence.
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
logger := log.New()
logger.SetLevel(log.PanicLevel)
enteredRemove := make(chan struct{})
blockRemove := make(chan struct{})
var once sync.Once
s := &Server{
Logger: logger,
mgmtClient: noopProxyClient{},
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
// Stand in for netbird.RemovePeer -> client.Stop hanging on
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
// blocks; later removes return immediately so the recovery assertion
// can observe the loop advancing.
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
first := false
once.Do(func() {
first = true
close(enteredRemove)
})
if !first {
return nil
}
select {
case <-blockRemove:
case <-ctx.Done():
}
return nil
},
}
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
// that must never be applied while the remove is wedged.
stream := &gatedMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
},
},
{
Mapping: []*proto.ProxyMapping{
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
},
},
},
}
loopDone := make(chan struct{})
syncDone := false
go func() {
defer close(loopDone)
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
}()
select {
case <-enteredRemove:
case <-time.After(2 * time.Second):
t.Fatal("receive loop never reached the blocking remove for the first batch")
}
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
// syncDone is owned by the loop goroutine, so it is not read here.
time.Sleep(500 * time.Millisecond)
assert.Equal(t, int32(1), stream.deliveredCount(),
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
select {
case <-loopDone:
t.Fatal("receive loop returned while it should be wedged on the remove")
default:
}
// Unblock and confirm the wedge was solely the blocked remove: the loop
// then advances and consumes the next batch.
close(blockRemove)
assert.Eventually(t, func() bool {
return stream.deliveredCount() >= 2
}, 2*time.Second, 5*time.Millisecond,
"once the remove unblocks, the loop must advance and consume the next batch")
}

View File

@@ -73,7 +73,7 @@ func benchServerWithLatency(b *testing.B, createPeerDelay, statusDelay time.Dura
statusUpdateDelay: statusDelay,
}
nb := roundtrip.NewNetBird("bench-proxy", "bench.test",
nb := roundtrip.NewNetBird(b.Context(), "bench-proxy", "bench.test",
roundtrip.ClientConfig{MgmtAddr: "http://bench.test:9999"},
logger, nil, mgmtClient)

View File

@@ -24,6 +24,7 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
"github.com/pires/go-proxyproto"
prometheus2 "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -75,28 +76,30 @@ type portRouter struct {
}
type Server struct {
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
ctx context.Context
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
staticCertWatcher *certwatch.Watcher
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
meter *proxymetrics.Metrics
accessLog *accesslog.Logger
mainRouter *nbtcp.Router
mainPort uint16
udpMu sync.Mutex
udpRelays map[types.ServiceID]*udprelay.Relay
udpRelayWg sync.WaitGroup
portMu sync.RWMutex
portRouters map[uint16]*portRouter
svcPorts map[types.ServiceID][]uint16
lastMappings map[types.ServiceID]*proto.ProxyMapping
portRouterWg sync.WaitGroup
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
@@ -117,6 +120,9 @@ type Server struct {
// The mapping worker waits on this before processing updates.
routerReady chan struct{}
// removePeer defaults to netbird.RemovePeer; overridable in tests.
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
// inbound, when non-nil, manages per-account inbound listeners. Set by
// initPrivateInbound only when Private is true so the standalone
// proxy keeps its zero-overhead default path.
@@ -226,6 +232,10 @@ type Server struct {
// Zero means no cap (the proxy honors whatever management sends).
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
MaxSessionIdleTimeout time.Duration
// MappingBatchWatchdog bounds how long a single mapping batch may spend
// in processMappings before the receive loop reconnects to resync.
// Zero uses defaultMappingBatchWatchdog.
MappingBatchWatchdog time.Duration
}
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
@@ -281,7 +291,7 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID types.Ac
}
// inboundListenerProto resolves the per-account inbound listener state for
// the SendStatusUpdate payload. Returns nil when --private-inbound is off
// the SendStatusUpdate payload. Returns nil when --private is off
// or the account has no live listener so management treats the field as
// absent.
func (s *Server) inboundListenerProto(accountID types.AccountID) *proto.ProxyInboundListener {
@@ -528,10 +538,10 @@ func (s *Server) initManagementClient() error {
}
// initNetBirdClient builds the multi-tenant embedded NetBird client used
// for outbound RoundTripping and (when --private-inbound is on) per-account
// for outbound RoundTripping and (when --private is on) per-account
// inbound listeners.
func (s *Server) initNetBirdClient() {
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
s.netbird = roundtrip.NewNetBird(s.ctx, s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
@@ -606,7 +616,7 @@ func (s *Server) initDefaults() {
// If no ID is set then one can be generated.
if s.ID == "" {
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
s.ID = fmt.Sprintf("netbird-proxy-%s", uuid.NewString())
}
// Fallback version option in case it is not set.
if s.Version == "" {
@@ -784,6 +794,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
}
go certWatcher.Watch(ctx)
s.staticCertWatcher = certWatcher
tlsConfig.GetCertificate = certWatcher.GetCertificate
return tlsConfig, nil
}
@@ -1171,24 +1182,30 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
s.healthChecker.SetManagementConnected(false)
}
connected := false
onConnected := func() { connected = true }
var streamErr error
if syncSupported {
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
if isSyncUnimplemented(streamErr) {
syncSupported = false
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
}
} else {
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
}
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
// Stream established — reset backoff so the next failure retries quickly.
bo.Reset()
// Reset backoff only when a stream actually connected, so immediate
// connect failures still back off instead of spinning.
if connected {
bo.Reset()
}
if streamErr == nil {
return fmt.Errorf("stream closed by server")
@@ -1220,7 +1237,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
}
}
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
connectTime := time.Now()
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
@@ -1233,6 +1250,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return fmt.Errorf("create mapping stream: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1241,7 +1259,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
}
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
connectTime := time.Now()
stream, err := client.SyncMappings(ctx)
if err != nil {
@@ -1262,6 +1280,7 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
return fmt.Errorf("send sync init: %w", err)
}
onConnected()
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
@@ -1306,7 +1325,9 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
@@ -1390,7 +1411,9 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
batchStart := time.Now()
s.Logger.Debug("Received mapping update, starting processing")
s.processMappings(ctx, msg.GetMapping())
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
return err
}
s.Logger.Debug("Processing mapping update completed")
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
}
@@ -1455,6 +1478,44 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
return c
}
const defaultMappingBatchWatchdog = 2 * time.Minute
// mappingBatchWatchdog returns the configured batch watchdog or the default.
func (s *Server) mappingBatchWatchdog() time.Duration {
if s.MappingBatchWatchdog > 0 {
return s.MappingBatchWatchdog
}
return defaultMappingBatchWatchdog
}
// processMappingsGuarded applies a batch under a watchdog, returning an error
// if processing exceeds the watchdog so the caller reconnects and resyncs
// instead of wedging silently.
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
batchCtx, cancel := context.WithCancel(ctx)
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
s.processMappings(batchCtx, mappings)
}()
watchdog := s.mappingBatchWatchdog()
timer := time.NewTimer(watchdog)
defer timer.Stop()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
for _, mapping := range mappings {
@@ -1565,6 +1626,8 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
var wildcardHit bool
if s.acme != nil {
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
} else {
wildcardHit = s.staticCertCovers(d)
}
httpRoute := nbtcp.Route{
Type: nbtcp.RouteHTTP,
@@ -1589,6 +1652,26 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
return nil
}
// staticCertCovers reports whether the static certificate loaded when ACME is
// disabled covers the given domain, making it certificate-ready immediately —
// the equivalent of a wildcard hit in the ACME path. Domains the certificate
// does not cover are logged: clients connecting to them will get TLS errors.
func (s *Server) staticCertCovers(d domain.Domain) bool {
if s.staticCertWatcher == nil {
return false
}
leaf := s.staticCertWatcher.Leaf()
if leaf == nil {
return false
}
name := d.PunycodeString()
if err := leaf.VerifyHostname(name); err != nil {
s.Logger.Warnf("static certificate (SANs %v) does not cover domain %q: %v", leaf.DNSNames, name, err)
return false
}
return true
}
// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port.
func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
svcID := types.ServiceID(mapping.GetId())
@@ -1950,7 +2033,11 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
accountID := types.AccountID(mapping.GetAccountId())
svcKey := s.serviceKeyForMapping(mapping)
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
removePeer := s.removePeer
if removePeer == nil {
removePeer = s.netbird.RemovePeer
}
if err := removePeer(ctx, accountID, svcKey); err != nil {
s.Logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": mapping.GetId(),

View File

@@ -64,7 +64,7 @@ func quietLifecycleLogger() *log.Logger {
}
func TestStopBeforeStartIsNoOp(t *testing.T) {
srv := New(Config{Logger: quietLifecycleLogger()})
srv := New(t.Context(), Config{Logger: quietLifecycleLogger()})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
@@ -77,7 +77,7 @@ func TestStopBeforeStartIsNoOp(t *testing.T) {
}
func TestStartFailsWithoutManagement(t *testing.T) {
srv := New(Config{
srv := New(t.Context(), Config{
Logger: quietLifecycleLogger(),
ListenAddr: "127.0.0.1:0",
ManagementAddress: "://broken-url",
@@ -137,7 +137,7 @@ func TestRecordRunErrPreservesFirstFailure(t *testing.T) {
}
func TestStopSkipsShutdownWhenNeverStarted(t *testing.T) {
srv := New(Config{Logger: quietLifecycleLogger()})
srv := New(t.Context(), Config{Logger: quietLifecycleLogger()})
ctx, cancel := context.WithCancel(context.Background())
cancel()

89
proxy/static_cert_test.go Normal file
View File

@@ -0,0 +1,89 @@
package proxy
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/shared/management/domain"
)
func generateCertWithSANs(t *testing.T, dnsNames []string) (certPEM, keyPEM []byte) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: dnsNames[0]},
DNSNames: dnsNames,
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
keyDER, err := x509.MarshalECPrivateKey(key)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return certPEM, keyPEM
}
func newStaticWatcher(t *testing.T, dnsNames []string) *certwatch.Watcher {
t.Helper()
dir := t.TempDir()
certPEM, keyPEM := generateCertWithSANs(t, dnsNames)
certPath := filepath.Join(dir, "tls.crt")
keyPath := filepath.Join(dir, "tls.key")
require.NoError(t, os.WriteFile(certPath, certPEM, 0o600))
require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600))
w, err := certwatch.NewWatcher(certPath, keyPath, quietLifecycleLogger())
require.NoError(t, err)
return w
}
func TestStaticCertCovers(t *testing.T) {
s := &Server{
Logger: quietLifecycleLogger(),
staticCertWatcher: newStaticWatcher(t, []string{"*.p.example.com", "exact.example.com"}),
}
cases := []struct {
domain string
covered bool
}{
{"svc.p.example.com", true},
{"exact.example.com", true},
{"a.b.p.example.com", false}, // wildcard does not span labels
{"p.example.com", false},
{"other.example.com", false},
}
for _, tc := range cases {
t.Run(tc.domain, func(t *testing.T) {
assert.Equal(t, tc.covered, s.staticCertCovers(domain.Domain(tc.domain)))
})
}
}
func TestStaticCertCoversNoWatcher(t *testing.T) {
s := &Server{Logger: quietLifecycleLogger()}
assert.False(t, s.staticCertCovers(domain.Domain("svc.p.example.com")))
}

View File

@@ -417,15 +417,30 @@ if type uname >/dev/null 2>&1; then
# Check the availability of a compatible package manager
if check_use_bin_variable; then
PACKAGE_MANAGER="bin"
elif [ -e /run/ostree-booted ]; then
if [ -x "$(command -v rpm-ostree)" ]; then
PACKAGE_MANAGER="rpm-ostree"
echo "The installation will be performed using rpm-ostree package manager"
elif [ -x "$(command -v bootc)" ]; then
echo "Detected bootc system without rpm-ostree." >&2
echo "NetBird cannot be installed via package manager on this system." >&2
echo "Options:" >&2
echo " 1. Install via Distrobox (instructions in the installation docs)" >&2
echo " 2. Rebuild your base image with rpm-ostree included" >&2
echo " 3. Bake NetBird into your Containerfile" >&2
exit 1
else
echo "Detected ostree-booted system without rpm-ostree or bootc." >&2
echo "NetBird cannot be installed automatically on this atomic system." >&2
echo "Please install NetBird by rebuilding your base image or use a supported package manager." >&2
exit 1
fi
elif [ -x "$(command -v apt-get)" ]; then
PACKAGE_MANAGER="apt"
echo "The installation will be performed using apt package manager"
elif [ -x "$(command -v dnf)" ]; then
PACKAGE_MANAGER="dnf"
echo "The installation will be performed using dnf package manager"
elif [ -x "$(command -v rpm-ostree)" ]; then
PACKAGE_MANAGER="rpm-ostree"
echo "The installation will be performed using rpm-ostree package manager"
elif [ -x "$(command -v yum)" ]; then
PACKAGE_MANAGER="yum"
echo "The installation will be performed using yum package manager"

View File

@@ -6,4 +6,5 @@ const (
RoleKey = "role"
UserIDKey = "userID"
PeerIDKey = "peerID"
UserAgentKey = "userAgent"
)

Some files were not shown because too many files have changed in this diff Show More