Compare commits

...

114 Commits

Author SHA1 Message Date
pascal
28303d7ded Merge branch 'main' into feature/check-cert-locker-before-acme 2026-03-10 14:25:50 +01:00
Pascal Fischer
5585adce18 [management] add activity events for domains (#5548)
* add activity events for domains

* fix test

* update activity codes

* update activity codes
2026-03-09 19:04:04 +01:00
Pascal Fischer
f884299823 [proxy] refactor metrics and add usage logs (#5533)
* **New Features**
  * Access logs now include bytes_upload and bytes_download (API and schemas updated, fields required).
  * Certificate issuance duration is now recorded as a metric.

* **Refactor**
  * Metrics switched from Prometheus client to OpenTelemetry-backed meters; health endpoint now exposes OpenMetrics via OTLP exporter.

* **Tests**
  * Metric tests updated to use OpenTelemetry Prometheus exporter and MeterProvider.
2026-03-09 18:45:45 +01:00
Maycon Santos
15aa6bae1b [client] Fix exit node menu not refreshing on Windows (#5553)
* [client] Fix exit node menu not refreshing on Windows

TrayOpenedCh is not implemented in the systray library on Windows,
so exit nodes were never refreshed after the initial connect. Combined
with the management sync not having populated routes yet when the
Connected status fires, this caused the exit node menu to remain empty
permanently after disconnect/reconnect cycles.
Add a background poller on Windows that refreshes exit nodes while
connected, with fast initial polling to catch routes from management
sync followed by a steady 10s interval. On macOS/Linux, TrayOpenedCh
continues to handle refreshes on each tray open.
Also fix a data race on connectClient assignment in the server's connect()
method and add nil checks in CleanState/DeleteState to prevent panics
when connectClient is nil.

* Remove unused exitNodeIDs

* Remove unused exitNodeState struct
2026-03-09 18:39:11 +01:00
Pascal Fischer
11eb725ac8 [management] only count login request duration for successful logins (#5545) 2026-03-09 14:56:46 +01:00
pascal
a19611d8e0 check the cert locker before starting the acme challenge 2026-03-09 13:40:41 +01:00
Pascal Fischer
30c02ab78c [management] use the cache for the pkce state (#5516) 2026-03-09 12:23:06 +01:00
Zoltan Papp
3acd86e346 [client] "reset connection" error on wake from sleep (#5522)
Capture engine reference before actCancel() in cleanupConnection().

After actCancel(), the connectWithRetryRuns goroutine sets engine to nil,
causing connectClient.Stop() to skip shutdown. This allows the goroutine
to set ErrResetConnection on the shared state after Down() clears it,
causing the next Up() to fail.
2026-03-09 10:25:51 +01:00
Pascal Fischer
5c20f13c48 [management] fix domain uniqueness (#5529) 2026-03-07 10:46:37 +01:00
Pascal Fischer
e6587b071d [management] use realip for proxy registration (#5525) 2026-03-06 16:11:44 +01:00
Maycon Santos
85451ab4cd [management] Add stable domain resolution for combined server (#5515)
The combined server was using the hostname from exposedAddress for both
singleAccountModeDomain and dnsDomain, causing fresh installs to get
the wrong domain and existing installs to break if the config changed.
 Add resolveDomains() to BaseServer that reads domain from the store:
  - Fresh install (0 accounts): uses "netbird.selfhosted" default
  - Existing install: reads persisted domain from the account in DB
  - Store errors: falls back to default safely

The combined server opts in via AutoResolveDomains flag, while the
 standalone management server is unaffected.
2026-03-06 08:43:46 +01:00
Pascal Fischer
a7f3ba03eb [management] aggregate grpc metrics by accountID (#5486) 2026-03-05 22:10:45 +01:00
Maycon Santos
4f0a3a77ad [management] Avoid breaking single acc mode when switching domains (#5511)
* **Bug Fixes**
  * Fixed domain configuration handling in single account mode to properly retrieve and apply domain settings from account data.
  * Improved error handling when account data is unavailable with fallback to configured default domain.

* **Tests**
  * Added comprehensive test coverage for single account mode domain configuration scenarios, including edge cases for missing or unavailable account data.
2026-03-05 14:30:31 +01:00
Maycon Santos
44655ca9b5 [misc] add PR title validation workflow (#5503) 2026-03-05 11:43:18 +01:00
Viktor Liu
e601278117 [management,proxy] Add per-target options to reverse proxy (#5501) 2026-03-05 10:03:26 +01:00
Maycon Santos
8e7b016be2 [management] Replace in-memory expose tracker with SQL-backed operations (#5494)
The expose tracker used sync.Map for in-memory TTL tracking of active expose sessions, which broke and lost all sessions on restart.

Replace with SQL-backed operations that reuse the existing meta_last_renewed_at column:

- Add store methods: RenewEphemeralService, GetExpiredEphemeralServices, CountEphemeralServicesByPeer, EphemeralServiceExists
- Move duplicate/limit checks inside a transaction with row-level locking (SELECT ... FOR UPDATE) to prevent concurrent bypass
- Reaper re-checks expiry under row lock to avoid deleting a just-renewed service and prevent duplicate event emission 
- Add composite index on (source, source_peer) for efficient queries
- Batch-limit and column-select the reaper query to avoid DB/GC spikes
- Filter out malformed rows with empty source_peer
2026-03-04 18:15:13 +01:00
Maycon Santos
9e01ea7aae [misc] Add ISSUE_TEMPLATE configuration file (#5500)
Add issue template config file  with support and troubleshooting links
2026-03-04 14:30:54 +01:00
hbzhost
cfc7ec8bb9 [client] Fix SSH JWT auth failure with Azure Entra ID iat backdating (#5471)
Increase DefaultJWTMaxTokenAge from 5 to 10 minutes to accommodate
identity providers like Azure Entra ID that backdate the iat claim
by up to 5 minutes, causing tokens to be immediately rejected.

Fixes #5449

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-04 14:11:14 +01:00
Misha Bragin
b3bbc0e5c6 Fix embedded IdP metrics to count local and generic OIDC users (#5498) 2026-03-04 12:34:11 +02:00
Pascal Fischer
d7c8e37ff4 [management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-03-03 18:39:46 +01:00
Zoltan Papp
05b66e73bc [client] Fix deadlock in route peer status watcher (#5489)
Wrap peerStateUpdate send in a nested select to prevent goroutine
blocking when the consumer has exited, which could fill the
subscription buffer and deadlock the Status mutex.
2026-03-03 13:50:46 +01:00
Jeremie Deray
01ceedac89 [client] Fix profile config directory permissions (#5457)
* fix user profile dir perm

* fix fileExists

* revert return var change

* fix anti-pattern
2026-03-03 13:48:51 +01:00
Misha Bragin
403babd433 [self-hosted] specify sql file location of auth, activity and main store (#5487) 2026-03-03 12:53:16 +02:00
Maycon Santos
47133031e5 [client] fix: client/Dockerfile to reduce vulnerabilities (#5217)
Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-03-03 08:44:08 +01:00
Pascal Fischer
82da606886 [management] Add explicit target delete on service removal (#5420) 2026-03-02 18:25:44 +01:00
Viktor Liu
bbe5ae2145 [client] Flush buffer immediately to support gprc (#5469) 2026-03-02 15:17:08 +01:00
Viktor Liu
0b21498b39 [client] Fix close of closed channel panic in ConnectClient retry loop (#5470) 2026-03-02 10:07:53 +01:00
Viktor Liu
0ca59535f1 [management] Add reverse proxy services REST client (#5454) 2026-02-28 13:04:58 +08:00
Misha Bragin
59c77d0658 [self-hosted] support embedded IDP postgres db (#5443)
* Add postgres config for embedded idp

Entire-Checkpoint: 9ace190c1067

* Rename idpStore to authStore

Entire-Checkpoint: 73a896c79614

* Fix review notes

Entire-Checkpoint: 6556783c0df3

* Don't accept pq port = 0

Entire-Checkpoint: 80d45e37782f

* Optimize configs

Entire-Checkpoint: 80d45e37782f

* Fix lint issues

Entire-Checkpoint: 3eec968003d1

* Fail fast on combined postgres config

Entire-Checkpoint: b17839d3d8c6

* Simplify management config method

Entire-Checkpoint: 0f083effa20e
2026-02-27 14:52:54 +01:00
shuuri-labs
333e045099 Lower socket auto-discovery log from Info to Debug (#5463)
The discovery message was printing on every CLI invocation, which is
noisy for users on distros using the systemd template.
2026-02-26 17:51:38 +01:00
Zoltan Papp
c2c4d9d336 [client] Fix Server mutex held across waitForUp in Up() (#5460)
Up() acquired s.mutex with a deferred unlock, then called waitForUp()
while still holding the lock. waitForUp() blocks for up to 50 seconds
waiting on clientRunningChan/clientGiveUpChan, starving all concurrent
gRPC calls that require the same mutex (Status, ListProfiles, etc.).

Replace the deferred unlock with explicit s.mutex.Unlock() on every
early-return path and immediately before waitForUp(), matching the
pattern already used by the clientRunning==true branch.
2026-02-26 16:47:02 +01:00
Bethuel Mmbaga
9a6a72e88e [management] Fix user update permission validation (#5441) 2026-02-24 22:47:41 +03:00
Bethuel Mmbaga
afe6d9fca4 [management] Prevent deletion of groups linked to flow groups (#5439) 2026-02-24 21:19:43 +03:00
shuuri-labs
ef82905526 [client] Add non default socket file discovery (#5425)
- Automatic Unix daemon address discovery: if the default socket is missing, the client can find and use a single available socket.
- Client startup now resolves daemon addresses more robustly while preserving non-Unix behavior.
2026-02-24 17:02:06 +01:00
Zoltan Papp
d18747e846 [client] Exclude Flow domain from caching to prevent TLS failures (#5433)
* Exclude Flow domain from caching to prevent TLS failures due to stale records.

* Fix test
2026-02-24 16:48:38 +01:00
Maycon Santos
f341d69314 [management] Add custom domain counts and service metrics to self-hosted metrics (#5414) 2026-02-24 15:21:14 +01:00
Maycon Santos
327142837c [management] Refactor expose feature: move business logic from gRPC to manager (#5435)
Consolidate all expose business logic (validation, permission checks, TTL tracking, reaping) into the manager layer, making the gRPC layer a pure transport adapter that only handles proto conversion and authentication.

- Add ExposeServiceRequest/ExposeServiceResponse domain types with validation in the reverseproxy package
- Move expose tracker (TTL tracking, reaping, per-peer limits) from gRPC server into manager/expose_tracker.go
- Internalize tracking in CreateServiceFromPeer, RenewServiceFromPeer, and new StopServiceFromPeer so callers don't manage tracker state
- Untrack ephemeral services in DeleteService/DeleteAllServices to keep tracker in sync when services are deleted via API
- Simplify gRPC expose handlers to parse, auth, convert, delegate
- Remove tracker methods from Manager interface (internal detail)
2026-02-24 15:09:30 +01:00
Zoltan Papp
f8c0321aee [client] Simplify DNS logging by removing domain list from log output (#5396) 2026-02-24 10:35:45 +01:00
Zoltan Papp
89115ff76a [client] skip UAPI listener in netstack mode (#5397)
In netstack (proxy) mode, the process lacks permission to create
/var/run/wireguard, making the UAPI listener unnecessary and causing
a misleading error log. Introduce NewUSPConfigurerNoUAPI and use it
for the netstack device to avoid attempting to open the UAPI socket
entirely. Also consolidate UAPI error logging to a single call site.
2026-02-24 10:35:23 +01:00
Maycon Santos
63c83aa8d2 [client,management] Feature/client service expose (#5411)
CLI: new expose command to publish a local port with flags for PIN, password, user groups, custom domain, name prefix and protocol (HTTP default).
Management/API: create/renew/stop expose sessions (streamed status), automatic naming/domain, TTL renewals, background expiration, new management RPCs and client methods.
UI/API: account settings now include peer_expose_enabled and peer_expose_groups; new activity codes for peer expose events.
2026-02-24 10:02:16 +01:00
Zoltan Papp
37f025c966 Fix a race condition where a concurrent user-issued Up or Down command (#5418)
could interleave with a sleep/wake event causing out-of-order state
transitions. The mutex now covers the full duration of each handler
including the status check, the Up/Down call, and the flag update.

Note: if Up or Down commands are triggered in parallel with sleep/wake
events, the overall ordering of up/down/sleep/wake operations is still
not guaranteed beyond what the mutex provides within the handler itself.
2026-02-24 10:00:33 +01:00
Zoltan Papp
4a54f0d670 [Client] Remove connection semaphore (#5419)
* [Client] Remove connection semaphore

Remove the semaphore and the initial random sleep time (300ms) from the connectivity logic to speed up the initial connection time.

Note: Implement limiter logic that can prioritize router peers and keep the fast connection option for the first few peers.

* Remove unused function
2026-02-23 20:58:53 +01:00
Zoltan Papp
98890a29e3 [client] fix busy-loop in network monitor routing socket on macOS/BSD (#5424)
* [client] fix busy-loop in network monitor routing socket on macOS/BSD

After system wakeup, the AF_ROUTE socket created by Go's unix.Socket()
is non-blocking, causing unix.Read to return EAGAIN immediately and spin
at 100% CPU filling the log with thousands of warnings per second.

Replace the tight read loop with a unix.Select call that blocks until
the fd is readable, checking ctx cancellation on each 1-second timeout.
Fatal errors (EBADF, EINVAL) now return an error instead of looping.

* [client] add fd range validation in waitReadable to prevent out-of-bound errors
2026-02-23 20:58:27 +01:00
Pascal Fischer
9d123ec059 [proxy] add pre-shared key support (#5377) 2026-02-23 16:31:29 +01:00
Pascal Fischer
5d171f181a [proxy] Send proxy updates on account delete (#5375) 2026-02-23 16:08:28 +01:00
Vlad
22f878b3b7 [management] network map components assembling (#5193) 2026-02-23 15:34:35 +01:00
Misha Bragin
44ef1a18dd [self-hosted] add Embedded IdP metrics (#5407) 2026-02-22 11:58:35 +02:00
Misha Bragin
2b98dc4e52 [self-hosted] Support activity store engine in the combined server (#5406) 2026-02-22 11:58:17 +02:00
Zoltan Papp
2a26cb4567 [client] stop upstream retry loop immediately on context cancellation (#5403)
stop upstream retry loop immediately on context cancellation
2026-02-20 14:44:14 +01:00
Pascal Fischer
5ca1b64328 [management] access log sorting (#5378) 2026-02-20 00:11:55 +01:00
Pascal Fischer
36752a8cbb [proxy] add access log cleanup (#5376) 2026-02-20 00:11:28 +01:00
Maycon Santos
f117fc7509 [client] Log lock acquisition time in receive message handling (#5393)
* Log lock acquisition time in receive message handling

* use offerAnswer.SessionID for session id
2026-02-19 19:18:47 +01:00
Zoltan Papp
fc6b93ae59 [ios] Ensure route settlement on iOS before handling DNS responses (#5360)
* Ensure route settlement on iOS before handling DNS responses to prevent bypassing the tunnel.

* add more logs

* rollback debug changes

* rollback  changes

* [client] Improve logging and add comments for iOS route settlement logic

- Switch iOS route settlement log level from Debug to Trace for finer control.
- Add clarifying comments for `waitForRouteSettlement` on non-iOS platforms.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-02-19 18:53:10 +01:00
Vlad
564fa4ab04 [management] fix possible race condition on user role change (#5395) 2026-02-19 18:34:28 +01:00
Maycon Santos
a6db88fbd2 [misc] Update timestamp format with milliseconds (#5387)
* Update timestamp format with milliseconds

* fix tests
2026-02-19 11:23:42 +01:00
Misha Bragin
4b5294e596 [self-hosted] remove unused config example (#5383) 2026-02-19 08:14:11 +01:00
shuuri-labs
a322dce42a [self-hosted] create migration script for pre v0.65.0 to post v0.65.0 (combined) (#5350) 2026-02-18 20:59:55 +01:00
Maycon Santos
d1ead2265b [client] Batch macOS DNS domains to avoid truncation (#5368)
* [client] Batch macOS DNS domains across multiple scutil keys to avoid truncation

scutil has undocumented limits: 99-element cap on d.add arrays and ~2048
  byte value buffer for SupplementalMatchDomains. Users with 60+ domains
  hit silent domain loss. This applies the same batching approach used on
  Windows (nrptMaxDomainsPerRule=50), splitting domains into indexed
  resolver keys (NetBird-Match-0, NetBird-Match-1, etc.) with 50-element
  and 1500-byte limits per key.

* check for all keys on getRemovableKeysWithDefaults

* use multi error
2026-02-18 19:14:09 +01:00
Maycon Santos
bbca74476e [management] docker login on management tests (#5323) 2026-02-18 16:11:17 +01:00
Zoltan Papp
318cf59d66 [relay] reduce QUIC initial packet size to 1280 (IPv6 min MTU) (#5374)
* [relay] reduce QUIC initial packet size to 1280 (IPv6 min MTU)

* adjust QUIC initial packet size to 1232 based on RFC 9000 §14
2026-02-18 10:58:14 +01:00
Pascal Fischer
e9b2a6e808 [managment] add flag to disable the old legacy grpc endpoint (#5372) 2026-02-17 19:53:14 +01:00
Zoltan Papp
2dbdb5c1a7 [client] Refactor WG endpoint setup with role-based proxy activation (#5277)
* Refactor WG endpoint setup with role-based proxy activation

For relay connections, the controller (initiator) now activates the
wgProxy before configuring the WG endpoint, while the non-controller
(responder) configures the endpoint first with a delayed update, then
activates the proxy after. This prevents the responder from sending
traffic through the proxy before WireGuard is ready to receive it,
avoiding handshake congestion when both sides try to initiate
simultaneously.

For ICE connections, pass hasRelayBackup as the setEndpointNow flag
so the responder sets the endpoint immediately when a relay fallback
exists (avoiding the delayed update path since relay is already
available as backup).

On ICE disconnect with relay fallback, remove the duplicate
wgProxyRelay.Work() calls — the relay proxy is already active from
initial setup, so re-activating it is unnecessary.

In EndpointUpdater, split ConfigureWGEndpoint into explicit
configureAsInitiator and configureAsResponder paths, and add the
setEndpointNow parameter to let the caller control whether the
responder applies the endpoint immediately or defers it. Add unused
SwitchWGEndpoint and RemoveEndpointAddress methods. Remove the
wgConfigWorkaround sleep from the relay setup path.

* Fix redundant wgProxyRelay.Work() call during relay fallback setup

* Simplify WireGuard endpoint configuration by removing unused parameters and redundant logic
2026-02-17 19:28:26 +01:00
Pascal Fischer
2cdab6d7b7 [proxy] remove unused oidc config flags (#5369) 2026-02-17 18:04:30 +01:00
Diego Noguês
e49c0e8862 [infrastructure] Proxy infra changes (#5365)
* chore: remove docker extra_hosts settings

* chore: remove unnecessary envc from proxy.env
2026-02-17 17:37:44 +01:00
Misha Bragin
e7c84d0ead Start Management if external IdP is down (#5367)
Set ContinueOnConnectorFailure: true in the embedded Dex config so that the Management server starts successfully even when an external IdP connector is unreachable at boot time.
2026-02-17 16:08:41 +01:00
Zoltan Papp
1c934cca64 Ignore false lint alert (#5370) 2026-02-17 16:07:35 +01:00
Vlad
4aff4a6424 [management] fix utc difference on last seen status for a peer (#5348) 2026-02-17 13:29:32 +01:00
Zoltan Papp
1bd7190954 [proxy] Support WebSocket (#5312)
* Fix WebSocket support by implementing Hijacker interface

Add responsewriter.PassthroughWriter to preserve optional HTTP interfaces
(Hijacker, Flusher, Pusher) when wrapping http.ResponseWriter in middleware.

Without this delegation:
 - WebSocket connections fail (can't hijack the connection)
 - Streaming breaks (can't flush buffers)
 - HTTP/2 push doesn't work

* Add HijackTracker to manage hijacked connections during graceful shutdown

* Refactor HijackTracker to use middleware for tracking hijacked connections

* Refactor server handler chain setup for improved readability and maintainability
2026-02-17 12:53:34 +01:00
Viktor Liu
0146e39714 Add listener side proxy protocol support and enable it in traefik (#5332)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-02-16 23:40:10 +01:00
Zoltan Papp
baed6e46ec Reset WireGuard endpoint on ICE session change during relay fallback (#5283)
When an ICE connection disconnects and falls back to relay, reset the
WireGuard endpoint and handshake watcher if the remote peer's ICE session
has changed. This ensures the controller re-establishes a fresh WireGuard
handshake rather than waiting on a stale endpoint from the previous session.
2026-02-16 20:59:29 +01:00
Maycon Santos
0d1ffba75f [misc] add additional cname example (#5341) 2026-02-16 13:30:58 +01:00
Diego Romar
1024d45698 [mobile] Export lazy connection environment variables for mobile clients (#5310)
* [client] Export lazy connection env vars

Both for Android and iOS

* [client] Separate comments
2026-02-16 09:04:45 -03:00
Zoltan Papp
e5d4947d60 [client] Optimize Windows DNS performance with domain batching and batch mode (#5264)
* Optimize Windows DNS performance with domain batching and batch mode

Implement two-layer optimization to reduce Windows NRPT registry operations:

1. Domain Batching (host_windows.go):
  - Batch domains per NRPT
  - Reduces NRPT rules by ~97% (e.g., 184 domains: 184 rules → 4 rules)
  - Modified addDNSMatchPolicy() to create batched NRPT entries
  - Added comprehensive tests in host_windows_test.go

2. Batch Mode (server.go):
  - Added BeginBatch/EndBatch methods to defer DNS updates
  - Modified RegisterHandler/DeregisterHandler to skip applyHostConfig in batch mode
  - Protected all applyHostConfig() calls with batch mode checks
  - Updated route manager to wrap route operations with batch calls

* Update tests

* Fix log line

* Fix NRPT rule index to ensure cleanup covers partially created rules

* Ensure NRPT entry count updates even on errors to improve cleanup reliability

* Switch DNS batch mode logging from Info to Debug level

* Fix batch mode to not suppress critical DNS config updates

Batch mode should only defer applyHostConfig() for RegisterHandler/
DeregisterHandler operations. Management updates and upstream nameserver
failures (deactivate/reactivate callbacks) need immediate DNS config
updates regardless of batch mode to ensure timely failover.

Without this fix, if a nameserver goes down during a route update,
the system DNS config won't be updated until EndBatch(), potentially
delaying failover by several seconds.

Or if you prefer a shorter version:

Fix batch mode to allow immediate DNS updates for critical paths

Batch mode now only affects RegisterHandler/DeregisterHandler.
Management updates and nameserver failures always trigger immediate
DNS config updates to ensure timely failover.

* Add DNS batch cancellation to rollback partial changes on errors

Introduces CancelBatch() method to the DNS server interface to handle error
scenarios during batch operations. When route updates fail partway through, the DNS
server can now discard accumulated changes instead of applying partial state. This
prevents leaving the DNS configuration in an inconsistent state when route manager
operations encounter errors.

The changes add error-aware batch handling to prevent partial DNS configuration
updates when route operations fail, which improves system reliability.
2026-02-15 22:10:26 +01:00
Maycon Santos
cb9b39b950 [misc] add extra proxy domain instructions (#5328)
improve proxy domain instructions
expose wireguard port
2026-02-15 12:51:46 +01:00
Bethuel Mmbaga
68c481fa44 [management] Move service reload outside transaction in account settings update (#5325)
Bug Fixes

Network and DNS updates now defer service and reverse-proxy reloads until after account updates complete, preventing inconsistent proxy state and race conditions.
Chores

Removed automatic peer/broadcast updates immediately following bulk service reloads.
Tests

Added a test ensuring network-range changes complete without deadlock.
2026-02-14 20:27:15 +01:00
Misha Bragin
01a9cd4651 [misc] Fix reverse proxy getting started messaging (#5317)
* Fix reverse proxy getting started messaging

* Fix reverse proxy getting started messaging
2026-02-14 16:34:04 +01:00
Pascal Fischer
f53155562f [management, reverse proxy] Add reverse proxy feature (#5291)
* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
2026-02-13 19:37:43 +01:00
Zoltan Papp
edce11b34d [client] Refactor/relay conn container (#5271)
* Fix race condition and ensure correct message ordering in
connection establishment

Reorder operations in OpenConn to register the connection before
waiting for peer availability. This ensures:

- Connection is ready to receive messages before peer subscription
completes
- Transport messages and onconnected events maintain proper ordering
- No messages are lost during the connection establishment window
- Concurrent OpenConn calls cannot create duplicate connections

If peer availability check fails, the pre-registered connection is
properly cleaned up.

* Handle service shutdown during relay connection initialization

Ensure relay connections are properly cleaned up when the service is not running by verifying `serviceIsRunning` and removing stale entries from `c.conns` to prevent unintended behaviors.

* Refactor relay client Conn/connContainer ownership and decouple Conn from Client

Conn previously held a direct *Client pointer and called client methods
(writeTo, closeConn, LocalAddr) directly, creating a tight bidirectional
coupling. The message channel was also created externally in OpenConn and
shared between Conn and connContainer with unclear ownership.

Now connContainer fully owns the lifecycle of both the channel and the
Conn it wraps:
- connContainer creates the channel (sized by connChannelSize const)
  and the Conn internally via newConnContainer
- connContainer feeds messages into the channel (writeMsg), closes and
  drains it on shutdown (close)
- Conn reads from the channel (Read) but never closes it

Conn is decoupled from *Client by replacing the *Client field with
three function closures (writeFn, closeFn, localAddrFn) that are wired
by newConnContainer at construction time. Write, Close, and LocalAddr
delegate to these closures. This removes the direct dependency while
keeping the identity-check logic: writeTo and closeConn now compare
connContainer pointers instead of Conn pointers to verify the caller
is the current active connection for that peer.
2026-02-13 15:48:08 +01:00
Zoltan Papp
841b2d26c6 Add early message buffer for relay client (#5282)
Add early message buffer to capture transport messages
arriving before OpenConn completes, ensuring correct
message ordering and no dropped messages.
2026-02-13 15:41:26 +01:00
Bethuel Mmbaga
d3eeb6d8ee [misc] Add cloud api spec to public open api with rest client (#5222) 2026-02-13 15:08:47 +03:00
Bethuel Mmbaga
7ebf37ef20 [management] Enforce access control on accessible peers (#5301) 2026-02-13 12:46:43 +03:00
Misha Bragin
64b849c801 [self-hosted] add netbird server (#5232)
* Unified NetBird combined server (Management, Signal, Relay, STUN) as a single executable with richer YAML configuration, validation, and defaults.
  * Official Dockerfile/image for single-container deployment.
  * Optional in-process profiling endpoint for diagnostics.
  * Multiplexing to route HTTP/gRPC/WebSocket traffic via one port; runtime hooks to inject custom handlers.
* **Chores**
  * Updated deployment scripts, compose files, and reverse-proxy templates to target the combined server; added example configs and getting-started updates.
2026-02-12 19:24:43 +01:00
Maycon Santos
69d4b5d821 [misc] Update sign pipeline version (#5296) 2026-02-12 11:31:49 +01:00
Viktor Liu
3dfa97dcbd [client] Fix stale entries in nftables with no handle (#5272) 2026-02-12 09:15:57 +01:00
Viktor Liu
1ddc9ce2bf [client] Fix nil pointer panic in device and engine code (#5287) 2026-02-12 09:15:42 +01:00
Maycon Santos
2de1949018 [client] Check if login is required on foreground mode (#5295) 2026-02-11 21:42:36 +01:00
Vlad
fc88399c23 [management] fixed ischild check (#5279) 2026-02-10 20:31:15 +03:00
Zoltan Papp
6981fdce7e [client] Fix race condition and ensure correct message ordering in Relay (#5265)
* Fix race condition and ensure correct message ordering in
connection establishment

Reorder operations in OpenConn to register the connection before
waiting for peer availability. This ensures:

- Connection is ready to receive messages before peer subscription
completes
- Transport messages and onconnected events maintain proper ordering
- No messages are lost during the connection establishment window
- Concurrent OpenConn calls cannot create duplicate connections

If peer availability check fails, the pre-registered connection is
properly cleaned up.

* Handle service shutdown during relay connection initialization

Ensure relay connections are properly cleaned up when the service is not running by verifying `serviceIsRunning` and removing stale entries from `c.conns` to prevent unintended behaviors.
2026-02-09 11:34:24 +01:00
Viktor Liu
08403f64aa [client] Add env var to skip DNS probing (#5270) 2026-02-09 11:09:11 +01:00
Viktor Liu
391221a986 [client] Fix uspfilter duplicate firewall rules (#5269) 2026-02-09 10:14:02 +01:00
Zoltan Papp
7bc85107eb Adds timing measurement to handleSync to help diagnose sync performance issues (#5228) 2026-02-06 19:50:48 +01:00
Zoltan Papp
3be16d19a0 [management] Feature/grpc debounce msgtype (#5239)
* Add gRPC update debouncing mechanism

Implements backpressure handling for peer network map updates to
efficiently handle rapid changes. First update is sent immediately,
subsequent rapid updates are coalesced, ensuring only the latest
update is sent after a 1-second quiet period.

* Enhance unit test to verify peer count synchronization with debouncing and timeout handling

* Debounce based on type

* Refactor test to validate timer restart after pending update dispatch

* Simplify timer reset for Go 1.23+ automatic channel draining

Remove manual channel drain in resetTimer() since Go 1.23+ automatically
drains the timer channel when Stop() returns false, making the
select-case pattern unnecessary.
2026-02-06 19:47:38 +01:00
Vlad
af8f730bda [management] check stream start time for connecting peer (#5267) 2026-02-06 18:00:43 +01:00
eyJhb
c3f176f348 [client] Fix wrong URL being logged for DefaultAdminURL (#5252)
- DefaultManagementURL was being logged instead of DefaultAdminURL
2026-02-06 11:23:36 +01:00
Viktor Liu
0119f3e9f4 [client] Fix netstack detection and add wireguard port option (#5251)
- Add WireguardPort option to embed.Options for custom port configuration
- Fix KernelInterface detection to account for netstack mode
- Skip SSH config updates when running in netstack mode
- Skip interface removal wait when running in netstack mode
- Use BindListener for netstack to avoid port conflicts on same host
2026-02-06 10:03:01 +01:00
Viktor Liu
1b96648d4d [client] Always log dns forwader responses (#5262) 2026-02-05 14:34:35 +01:00
Zoltan Papp
d2f9653cea Fix nil pointer panic in ICE agent during sleep/wake cycles (#5261)
Add defensive nil checks in ThreadSafeAgent.Close() to prevent panic
when agent field is nil. This can occur during Windows suspend/resume
when network interfaces are disrupted or the pion/ice library returns
nil without error.

Also capture agent pointer in local variable before goroutine execution
to prevent race conditions.

Fixes service crashes on laptop wake-up.
2026-02-05 12:06:28 +01:00
Zoltan Papp
194a986926 Cache the result of wgInterface.ToInterface() using sync.Once (#5256)
Avoid repeated conversions during route setup. The toInterface helper ensures
the conversion happens only once regardless of how many routes are added
or removed.
2026-02-04 22:22:37 +01:00
Viktor Liu
f7732557fa [client] Add missing bsd flags in debug bundle (#5254) 2026-02-04 18:07:27 +01:00
Vlad
d488f58311 [management] fix set disconnected status for connected peer (#5247) 2026-02-04 11:44:46 +01:00
Pascal Fischer
6fdc00ff41 [management] adding account id validation to accessible peers handler (#5246) 2026-02-03 17:30:02 +01:00
Misha Bragin
b20d484972 [docs] Add selfhosting video (#5235) 2026-02-01 16:06:36 +01:00
Vlad
8931293343 [management] run cancelPeerRoutinesWithoutLock in sync (#5234) 2026-02-01 15:44:27 +01:00
Vlad
7b830d8f72 disable sync lim (#5233) 2026-02-01 14:37:00 +01:00
Misha Bragin
3a0cf230a1 Disable local users for a smooth single-idp mode (#5226)
Add LocalAuthDisabled option to embedded IdP configuration

This adds the ability to disable local (email/password) authentication when using the embedded Dex identity provider. When disabled, users can only authenticate via external
identity providers (Google, OIDC, etc.).

This simplifies user login when there is only one external IdP configured. The login page will redirect directly to the IdP login page.

Key changes:

Added LocalAuthDisabled field to EmbeddedIdPConfig
Added methods to check and toggle local auth: IsLocalAuthEnabled, HasNonLocalConnectors, DisableLocalAuth, EnableLocalAuth
Validation prevents disabling local auth if no external connectors are configured
Existing local users are preserved when disabled and can login again when re-enabled
Operations are idempotent (disabling already disabled is a no-op)
2026-02-01 14:26:22 +01:00
Viktor Liu
0c990ab662 [client] Add block inbound option to the embed client (#5215) 2026-01-30 10:42:39 +01:00
Viktor Liu
101c813e98 [client] Add macOS default resolvers as fallback (#5201) 2026-01-30 10:42:14 +01:00
Zoltan Papp
5333e55a81 Fix WG watcher missing initial handshake (#5213)
Start the WireGuard watcher before configuring the WG endpoint to ensure it captures the initial handshake timestamp.

Previously, the watcher was started after endpoint configuration, causing it to miss the handshake that occurred during setup.
2026-01-29 16:58:10 +01:00
Viktor Liu
81c11df103 [management] Streamline domain validation (#5211) 2026-01-29 13:51:44 +01:00
Viktor Liu
f74bc48d16 [Client] Stop NetBird on firewall init failure (#5208) 2026-01-29 11:05:06 +01:00
Vlad
0169e4540f [management] fix skip of ephemeral peers on deletion (#5206) 2026-01-29 10:58:45 +01:00
Vlad
cead3f38ee [management] fix ephemeral peers being not removed (#5203) 2026-01-28 18:24:12 +01:00
Zoltan Papp
b55262d4a2 [client] Refactor/optimise raw socket headers (#5174)
Pre-create and reuse packet headers to eliminate per-packet allocations.
2026-01-28 15:06:59 +01:00
Zoltan Papp
2248ff392f Remove redundant square bracket trimming in USP endpoint parsing (#5197) 2026-01-27 20:10:59 +01:00
433 changed files with 67437 additions and 2722 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

14
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,14 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies - name: Check for problematic license dependencies
run: | run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..." echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
echo "" echo ""
# Find all directories except the problematic ones and system dirs # Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do while IFS= read -r dir; do
echo "=== Checking $dir ===" echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files # Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:" echo "❌ Found problematic dependencies:"
echo "$RESULTS" echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else else
echo "✓ No problematic dependencies found" echo "✓ No problematic dependencies found"
fi fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort) done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
echo "" echo ""
if [ $FOUND_ISSUES -eq 1 ]; then if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages" echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1 exit 1
else else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath") IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay # Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1) BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER" echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)

View File

@@ -46,6 +46,5 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/... time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/... time go test -timeout 1m -failfast ./version/...

View File

@@ -97,6 +97,16 @@ jobs:
working-directory: relay working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
- name: Build combined
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 go build .
- name: Build combined 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
test: test:
name: "Client / Unit" name: "Client / Unit"
needs: [build-cache] needs: [build-cache]
@@ -144,7 +154,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
test_client_on_docker: test_client_on_docker:
name: "Client (Docker) / Unit" name: "Client (Docker) / Unit"
@@ -204,7 +214,7 @@ jobs:
sh -c ' \ sh -c ' \
apk update; apk add --no-cache \ apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \ ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server) go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
' '
test_relay: test_relay:
@@ -261,6 +271,53 @@ jobs:
-exec 'sudo' \ -exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/... -timeout 10m -p 1 ./relay/... ./shared/relay/...
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal: test_signal:
name: "Signal / Unit" name: "Signal / Unit"
needs: [build-cache] needs: [build-cache]
@@ -352,12 +409,19 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Login to Docker hub - name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v1 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image - name: download mysql image
if: matrix.store == 'mysql' if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8 run: docker pull mlsmaycon/warmed-mysql:8
@@ -440,15 +504,18 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Login to Docker hub - name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v1 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image - name: docker login for root user
if: matrix.store == 'mysql' if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
run: docker pull mlsmaycon/warmed-mysql:8 env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: Test - name: Test
run: | run: |
@@ -529,15 +596,18 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Login to Docker hub - name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v1 uses: docker/login-action@v3
with: with:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image - name: docker login for root user
if: matrix.store == 'mysql' if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
run: docker pull mlsmaycon/warmed-mysql:8 env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: Test - name: Test
run: | run: |

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -19,8 +19,8 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
skip: go.mod,go.sum skip: go.mod,go.sum,**/proxy/web/**
golangci: golangci:
strategy: strategy:
fail-fast: false fail-fast: false

51
.github/workflows/pr-title-check.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.1.0" SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -160,7 +160,7 @@ jobs:
username: ${{ secrets.DOCKER_USER }} username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry - name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
registry: ghcr.io registry: ghcr.io
@@ -176,6 +176,7 @@ jobs:
- name: Generate windows syso arm64 - name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
@@ -185,6 +186,19 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: Tag and push PR images (amd64 only)
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
run: |
PR_TAG="pr-${{ github.event.pull_request.number }}"
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
IMG_NAME="${SRC%%:*}"
DST="${IMG_NAME}:${PR_TAG}"
echo "Tagging ${SRC} -> ${DST}"
docker tag "$SRC" "$DST"
docker push "$DST"
done
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:

1
.gitignore vendored
View File

@@ -2,6 +2,7 @@
.run .run
*.iml *.iml
dist/ dist/
!proxy/web/dist/
bin/ bin/
.env .env
conf.json conf.json

View File

@@ -106,6 +106,26 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload - id: netbird-upload
dir: upload-server dir: upload-server
env: [CGO_ENABLED=0] env: [CGO_ENABLED=0]
@@ -120,6 +140,20 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-proxy
dir: proxy/cmd/proxy
env: [CGO_ENABLED=0]
binary: netbird-proxy
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries: universal_binaries:
- id: netbird - id: netbird
@@ -520,6 +554,104 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}" - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests: docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }} - name_template: netbirdio/netbird:{{ .Version }}
image_templates: image_templates:
@@ -598,6 +730,18 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm - netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64 - netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }} - name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates: image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8 - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -675,6 +819,43 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8 - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm - ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64 - ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
brews: brews:
- ids: - ids:
- default - default

View File

@@ -1,4 +1,4 @@
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/. This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory. Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
BSD 3-Clause License BSD 3-Clause License

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video) ### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) [![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features ### Key features

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.2 FROM alpine:3.23.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \

View File

@@ -1,10 +1,19 @@
package android package android
import "github.com/netbirdio/netbird/client/internal/peer" import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
var ( var (
// EnvKeyNBForceRelay Exported for Android java client // EnvKeyNBForceRelay Exported for Android java client to force relay connections
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
) )
// EnvList wraps a Go map for export to Java // EnvList wraps a Go map for export to Java

194
client/cmd/expose.go Normal file
View File

@@ -0,0 +1,194 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: "netbird expose --with-password safe-pass 8080",
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
})
if err != nil {
return fmt.Errorf("expose service: %w", err)
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
switch strings.ToLower(exposeProtocol) {
case "http":
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case "https":
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
}
switch e := event.Event.(type) {
case *proto.ExposeServiceEvent_Ready:
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Port: %d\n", port)
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
return nil
default:
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
} }
defer authClient.Close() defer authClient.Close()
needsLogin := false needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
err, isAuthError := authClient.Login(ctx, "", "") return fmt.Errorf("check login required: %v", err)
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
} }
jwtToken := "" jwtToken := ""

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
) )
@@ -80,6 +81,15 @@ var (
Short: "", Short: "",
Long: "", Long: "",
SilenceUsage: true, SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
} }
) )
@@ -144,6 +154,7 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd) rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd) rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -385,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
} }
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) { func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -398,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil return conn, nil
} }
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -31,6 +31,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized") ErrConfigNotInitialized = errors.New("config not initialized")
) )
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// Client manages a netbird embedded client instance. // Client manages a netbird embedded client instance.
type Client struct { type Client struct {
deviceName string deviceName string
@@ -69,6 +77,10 @@ type Options struct {
StatePath string StatePath string
// DisableClientRoutes disables the client routes // DisableClientRoutes disables the client routes
DisableClientRoutes bool DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
} }
// validateCredentials checks that exactly one credential type is provided // validateCredentials checks that exactly one credential type is provided
@@ -137,6 +149,8 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey, PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t, DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes, DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
} }
if opts.ConfigPath != "" { if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input) config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -156,6 +170,7 @@ func New(opts Options) (*Client, error) {
setupKey: opts.SetupKey, setupKey: opts.SetupKey,
jwtToken: opts.JWTToken, jwtToken: opts.JWTToken,
config: config, config: config,
recorder: peer.NewRecorder(config.ManagementURL.String()),
}, nil }, nil
} }
@@ -177,6 +192,7 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config) authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil { if err != nil {
return fmt.Errorf("create auth client: %w", err) return fmt.Errorf("create auth client: %w", err)
@@ -186,10 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err) return fmt.Errorf("login: %w", err)
} }
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true) client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
@@ -342,14 +355,9 @@ func (c *Client) NewHTTPClient() *http.Client {
// Status returns the current status of the client. // Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) { func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock() c.mu.Lock()
recorder := c.recorder
connect := c.connect connect := c.connect
c.mu.Unlock() c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil { if connect != nil {
engine := connect.Engine() engine := connect.Engine()
if engine != nil { if engine != nil {
@@ -357,7 +365,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
} }
} }
return recorder.GetFullStatus(), nil return c.recorder.GetFullStatus(), nil
} }
// GetLatestSyncResponse returns the latest sync response from the management server. // GetLatestSyncResponse returns the latest sync response from the management server.

View File

@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
} }
if nftRule.Handle == 0 { if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey) log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
} }
if err := r.deleteNftRule(nftRule, ruleKey); err != nil { if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
// TODO: rollback ipset counter r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err) return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
} }
return nil return nil
} }
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *router) rollbackRules(pair firewall.RouterPair) {
keys := []string{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue // addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error { func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { rule, exists := r.rules[ruleKey]
if err := r.conn.DelRule(rule); err != nil { if !exists {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return nil
} }
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil { if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err) log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
} }
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
} }
return nil return nil
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
var merr *multierror.Error
if pair.Masquerade { if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err) merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
} }
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err) merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
} }
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err) merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
} }
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
// TODO: rollback set counter merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
} }
return nil return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) removeNatRule(pair firewall.RouterPair) error { func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { rule, exists := r.rules[ruleKey]
if err := r.conn.DelRule(rule); err != nil { if !exists {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
log.Debugf("prerouting rule %s not found", ruleKey) log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
} }
return nil return nil
} }
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid // refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// duplicates and to get missing attributes that we don't have when adding new rules // (e.g. from failed flushes) and updates handles for all existing rules.
func (r *router) refreshRulesMap() error { func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains { for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain) rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil { if err != nil {
return fmt.Errorf("list rules: %w", err) merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
} }
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule newRules[string(rule.UserData)] = rule
} }
} }
} }
return nil r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
} }
var merr *multierror.Error var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists { if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil { if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
} }
} }
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists { if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil { if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
} }
} }
if err := r.conn.Flush(); err != nil { if needsFlush {
merr = multierror.Append(merr, fmt.Errorf(flushError, err)) if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
} }
if merr == nil { if merr == nil {
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists { rule, exists := r.rules[ruleID]
if err := r.conn.DelRule(rule); err != nil { if !exists {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) return nil
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
} }
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil return nil
} }

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
) )
const ( const (
@@ -719,3 +720,137 @@ func deleteWorkTable() {
} }
} }
} }
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}

View File

@@ -3,12 +3,6 @@
package uspfilter package uspfilter
import ( import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.resetState()
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager) return m.nativeFirewall.Close(stateManager)

View File

@@ -1,12 +1,9 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.resetState()
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {
return nil return nil

View File

@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load() return t.tombstone.Load()
} }
// IsSupersededBy returns true if this connection should be replaced by a new one
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
// connections are superseded by a pure SYN (a new connection attempt for the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
if t.tombstone.Load() {
return true
}
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion // SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() { func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true) t.tombstone.Store(true)
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if exists { if exists && !conn.IsSupersededBy(flags) {
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true return key, uint16(conn.DNATOrigPort.Load()), true
} }
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists || conn.IsTombstone() { if !exists || conn.IsSupersededBy(flags) {
return false return false
} }

View File

@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
}) })
} }
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) { func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing // Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond shortTimeout := 100 * time.Millisecond

View File

@@ -1,6 +1,7 @@
package uspfilter package uspfilter
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@@ -12,11 +13,13 @@ import (
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -24,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -89,6 +93,7 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger, flowLogger: flowLogger,
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr), dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{}, portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}), netstackServices: make(map[serviceKey]struct{}),
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
ruleID := uuid.New().String() ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
id: ruleID, id: string(ruleKey),
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, dstSet: destination.Set,
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule) m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort() m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil return &rule, nil
} }
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
ruleID := rule.ID() ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool { idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID return r.id == string(ruleKey)
}) })
if idx < 0 { if idx < 0 {
return fmt.Errorf("route rule not found: %s", ruleID) return fmt.Errorf("route rule not found in slice: %s", ruleKey)
} }
m.routeRules = slices.Delete(m.routeRules, idx, idx+1) m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil return nil
} }
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. // SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {

View File

@@ -0,0 +1,376 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
// filtering rule twice returns the same rule ID (idempotent behavior).
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{
netip.MustParsePrefix("100.64.1.0/24"),
netip.MustParsePrefix("100.64.2.0/24"),
}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
// different parameters get distinct IDs.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
// rule during a network map update does not disrupt existing traffic.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.True(t, passAfter,
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule1)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule2)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule3)
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
// Verify the rule blocks traffic to the WG network
srcIP := netip.MustParseAddr("10.0.0.1")
dstIP := netip.MustParseAddr("100.64.0.50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
// EnableRouting multiple times (as happens on each route update) does not
// accumulate duplicate block rules in the routeRules slice.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call EnableRouting multiple times (simulating repeated route updates)
for i := 0; i < 5; i++ {
require.NoError(t, manager.EnableRouting())
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount,
"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
// rule multiple times does not create duplicate entries.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
// after adding it multiple times works correctly.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.False(t, pass, "Traffic should not pass after rule deletion")
}
func setupTestManager(t *testing.T) *Manager {
t.Helper()
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
return manager
}

View File

@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
} }
} }
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
// to the deny map and can be cleanly deleted without leaving orphans.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
// peer rules (simulating network map updates) does not leak rules in the maps.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
// IP are stored in separate maps and don't interfere with each other.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) { func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },

View File

@@ -5,6 +5,8 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"os"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -16,9 +18,18 @@ const (
maxBatchSize = 1024 * 16 maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2 maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second defaultFlushInterval = 2 * time.Second
logChannelSize = 1000 defaultLogChanSize = 1000
) )
func getLogChannelSize() int {
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultLogChanSize
}
type Level uint32 type Level uint32
const ( const (
@@ -69,7 +80,7 @@ type Logger struct {
func NewFromLogrus(logrusLogger *log.Logger) *Logger { func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{ l := &Logger{
output: logrusLogger.Out, output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize), msgChannel: make(chan logMessage, getLogChannelSize()),
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() any { New: func() any {

View File

@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
// Fast path for IPv4 addresses (4 bytes) - most common case // Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 { if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2])) sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2])) sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
} else { } else {
// Fallback for other lengths // Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 { for i := 0; i < len(oldBytes)-1; i += 2 {

View File

@@ -5,20 +5,18 @@ package configurer
import ( import (
"net" "net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
) )
func openUAPI(deviceName string) (net.Listener, error) { func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName) uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil { if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err return nil, err
} }
listener, err := ipc.UAPIListen(deviceName, uapiSock) listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil { if err != nil {
log.Errorf("failed to listen on uapi socket: %v", err) _ = uapiSock.Close()
return nil, err return nil, err
} }

View File

@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
return wgCfg return wgCfg
} }
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key") log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey) key, err := wgtypes.ParseKey(privateKey)
@@ -558,7 +566,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue continue
} }
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]")) host, portStr, err := net.SplitHostPort(val)
if err != nil { if err != nil {
log.Errorf("failed to parse endpoint: %v", err) log.Errorf("failed to parse endpoint: %v", err)
continue continue

View File

@@ -29,8 +29,9 @@ type PacketFilter interface {
type FilteredDevice struct { type FilteredDevice struct {
tun.Device tun.Device
filter PacketFilter filter PacketFilter
mutex sync.RWMutex mutex sync.RWMutex
closeOnce sync.Once
} }
// newDeviceFilter constructor function // newDeviceFilter constructor function
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
} }
} }
// Close closes the underlying tun device exactly once.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
var err error
d.closeOnce.Do(func() {
err = d.Device.Close()
})
if err != nil {
return err
}
return nil
}
// Read wraps read method with filtering feature // Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil { if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -79,10 +79,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
_ = tunIface.Close() if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
return nil, fmt.Errorf("error configuring interface: %s", err) return nil, fmt.Errorf("error configuring interface: %s", err)
} }

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
@@ -228,6 +229,10 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
} }
if nbnetstack.IsEnabled() {
return errors.FormatErrorOrNil(result)
}
if err := w.waitUntilRemoved(); err != nil { if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
if err := w.Destroy(); err != nil { if err := w.Destroy(); err != nil {

View File

@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
} }
}() }()
return nsTunDev, tunNet, nil return t.tundev, tunNet, nil
} }
func (t *NetStackTun) Close() error { func (t *NetStackTun) Close() error {

View File

@@ -8,8 +8,6 @@ import (
"net" "net"
"sync" "sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -26,16 +24,6 @@ const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
@@ -253,63 +241,3 @@ generatePort:
} }
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var dstIP net.IP
var rawConn net.PacketConn
if endpointAddr.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
dstIP = localHostNetIPv4
rawConn = p.rawConnIPv4
} else {
// IPv6 path
if p.rawConnIPv6 == nil {
return fmt.Errorf("IPv6 raw socket not available")
}
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpointAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
dstIP = localHostNetIPv6
rawConn = p.rawConnIPv6
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,12 +10,89 @@ import (
"net" "net"
"sync" "sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy wgeBPFProxy *WGEBPFProxy
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
paused bool paused bool
pausedCond *sync.Cond pausedCond *sync.Cond
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr p.wgRelayedEndpointAddr = addr
return err p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
@@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
log.Errorf("failed to start package redirection, endpoint is nil") log.Errorf("failed to start package redirection, endpoint is nil")
return return
} }
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
p.pausedCond.Signal() p.pausedCond.Signal()
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
@@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait() p.pausedCond.Wait()
} }
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) err = p.sendPkg(buf[:n], p.headerCurrentUsed)
p.pausedCond.L.Unlock() p.pausedCond.L.Unlock()
if err != nil { if err != nil {
@@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
} }
return n, nil return n, nil
} }
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
}) })
} }
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// Apply the same rules 5 times (simulating repeated network map updates)
for i := 0; i < 5; i++ {
acl.ApplyFiltering(networkMap, false)
}
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
assert.Equal(t, 3, len(acl.peerRulesPairs),
"Should have exactly 3 rule pairs after 5 identical updates")
}
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: add deny and accept rules
networkMap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap1, false)
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
// Second update: remove the deny rule, keep only accept
networkMap2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap2, false)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Should have 1 rule after removing deny rule")
// Third update: remove all rules
networkMap3 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{},
FirewallRulesIsEmpty: true,
}
acl.ApplyFiltering(networkMap3, false)
assert.Equal(t, 0, len(acl.peerRulesPairs),
"Should have 0 rules after removing all rules")
}
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
// accept to deny (or vice versa), the old rule is properly removed and the new
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: accept rule
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Second update: change to deny (same IP/port/proto, different action)
networkMap.FirewallRules = []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
}
acl.ApplyFiltering(networkMap, false)
// Should still have exactly 1 rule (the old accept removed, new deny added)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Changing action should result in exactly 1 rule, not 2")
}
func TestPortInfoEmpty(t *testing.T) { func TestPortInfoEmpty(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(), PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded(), KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
FQDN: loginResp.GetPeerConfig().GetFqdn(), FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
c.statusRecorder.UpdateLocalPeerState(localPeerState) c.statusRecorder.UpdateLocalPeerState(localPeerState)
@@ -330,8 +331,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil {
close(runningChan) select {
runningChan = nil case <-runningChan:
default:
close(runningChan)
}
} }
<-engineCtx.Done() <-engineCtx.Done()

View File

@@ -0,0 +1,60 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
var scanDir = "/var/run/netbird"
// setScanDir overrides the scan directory (used by tests).
func setScanDir(dir string) {
scanDir = dir
}
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
// mismatch between the netbird@.service template (which places the socket under
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
func ResolveUnixDaemonAddr(addr string) string {
if !strings.HasPrefix(addr, "unix://") {
return addr
}
sockPath := strings.TrimPrefix(addr, "unix://")
if _, err := os.Stat(sockPath); err == nil {
return addr
}
entries, err := os.ReadDir(scanDir)
if err != nil {
return addr
}
var found []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".sock") {
found = append(found, filepath.Join(scanDir, e.Name()))
}
}
switch len(found) {
case 1:
resolved := "unix://" + found[0]
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
return resolved
case 0:
return addr
default:
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
return addr
}
}

View File

@@ -0,0 +1,8 @@
//go:build windows || ios || android
package daemonaddr
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
func ResolveUnixDaemonAddr(addr string) string {
return addr
}

View File

@@ -0,0 +1,121 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"testing"
)
// createSockFile creates a regular file with a .sock extension.
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
// sufficient and avoids Unix socket path-length limits on macOS.
func createSockFile(t *testing.T, path string) {
t.Helper()
if err := os.WriteFile(path, nil, 0o600); err != nil {
t.Fatalf("failed to create test sock file at %s: %v", path, err)
}
}
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
tmp := t.TempDir()
sock := filepath.Join(tmp, "netbird.sock")
createSockFile(t, sock)
addr := "unix://" + sock
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
tmp := t.TempDir()
// Default socket does not exist
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
// Create a scan dir with one socket
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
instanceSock := filepath.Join(sd, "main.sock")
createSockFile(t, instanceSock)
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
expected := "unix://" + instanceSock
if got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
createSockFile(t, filepath.Join(sd, "main.sock"))
createSockFile(t, filepath.Join(sd, "other.sock"))
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
addr := "tcp://127.0.0.1:41731"
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
origScanDir := scanDir
setScanDir(filepath.Join(tmp, "nonexistent"))
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}

View File

@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false, matchSubdomains: false,
shouldMatch: false, shouldMatch: false,
}, },
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@@ -9,9 +9,13 @@ import (
"io" "io"
"net/netip" "net/netip"
"os/exec" "os/exec"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -20,6 +24,7 @@ import (
const ( const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
globalIPv4State = "State:/Network/Global/IPv4" globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS" primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains" keySupplementalMatchDomains = "SupplementalMatchDomains"
@@ -33,11 +38,22 @@ const (
searchSuffix = "Search" searchSuffix = "Search"
matchSuffix = "Match" matchSuffix = "Match"
localSuffix = "Local" localSuffix = "Local"
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
maxDomainsPerResolverEntry = 50
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
maxDomainBytesPerResolverEntry = 1500
) )
type systemConfigurator struct { type systemConfigurator struct {
createdKeys map[string]struct{} createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
} }
func newHostManager() (*systemConfigurator, error) { func newHostManager() (*systemConfigurator, error) {
@@ -79,28 +95,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, ".")) searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) if err := s.removeKeysContaining(matchSuffix); err != nil {
var err error log.Warnf("failed to remove old match keys: %v", err)
if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing match domains from the system")
err = s.removeKeyFromSystemConfig(matchKey)
} }
if err != nil { if len(matchDomains) != 0 {
return fmt.Errorf("add match domains: %w", err) if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
return fmt.Errorf("add match domains: %w", err)
}
} }
s.updateState(stateManager) s.updateState(stateManager)
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if err := s.removeKeysContaining(searchSuffix); err != nil {
if len(searchDomains) != 0 { log.Warnf("failed to remove old search keys: %v", err)
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing search domains from the system")
err = s.removeKeyFromSystemConfig(searchKey)
} }
if err != nil { if len(searchDomains) != 0 {
return fmt.Errorf("add search domains: %w", err) if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
} }
s.updateState(stateManager) s.updateState(stateManager)
@@ -144,8 +155,7 @@ func (s *systemConfigurator) restoreHostDNS() error {
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string { func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 { if len(s.createdKeys) == 0 {
// return defaults for startup calls return s.discoverExistingKeys()
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
} }
keys := make([]string, 0, len(s.createdKeys)) keys := make([]string, 0, len(s.createdKeys))
@@ -155,6 +165,47 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
return keys return keys
} }
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
func (s *systemConfigurator) discoverExistingKeys() []string {
dnsKeys, err := getSystemDNSKeys()
if err != nil {
log.Errorf("failed to get system DNS keys: %v", err)
return nil
}
var keys []string
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
if strings.Contains(dnsKeys, key) {
keys = append(keys, key)
}
}
for _, suffix := range []string{searchSuffix, matchSuffix} {
for i := 0; ; i++ {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
if !strings.Contains(dnsKeys, key) {
break
}
keys = append(keys, key)
}
}
return keys
}
// getSystemDNSKeys gets all DNS keys
func getSystemDNSKeys() (string, error) {
command := "list .*DNS\nquit\n"
out, err := runSystemConfigCommand(command)
if err != nil {
return "", err
}
return string(out), nil
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key) line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line)) _, err := runSystemConfigCommand(wrapCommand(line))
@@ -179,12 +230,11 @@ func (s *systemConfigurator) addLocalDNS() error {
return nil return nil
} }
if err := s.addSearchDomains( domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
localKey, if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, return fmt.Errorf("add local dns state: %w", err)
); err != nil {
return fmt.Errorf("add search domains: %w", err)
} }
s.createdKeys[localKey] = struct{}{}
return nil return nil
} }
@@ -218,6 +268,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
var dnsSettings SystemDNSSettings var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false inSearchDomainsArray := false
inServerAddressesArray := false inServerAddressesArray := false
@@ -244,9 +295,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
dnsSettings.ServerIP = ip.Unmap() ip = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
} }
} }
} }
@@ -258,31 +312,90 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port // default to 53 port
dnsSettings.ServerPort = DefaultPort dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil return dnsSettings, nil
} }
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
err := s.addDNSState(key, domains, ip, port, true) s.mu.RLock()
if err != nil { defer s.mu.RUnlock()
return fmt.Errorf("add dns state: %w", err) return slices.Clone(s.origNameservers)
}
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
s.createdKeys[key] = struct{}{}
return nil
} }
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error { // splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
err := s.addDNSState(key, domains, dnsServer, port, false) func splitDomainsIntoBatches(domains []string) [][]string {
if err != nil { if len(domains) == 0 {
return fmt.Errorf("add dns state: %w", err) return nil
} }
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) var batches [][]string
var current []string
currentBytes := 0
s.createdKeys[key] = struct{}{} for _, d := range domains {
domainLen := len(d)
newBytes := currentBytes + domainLen
if currentBytes > 0 {
newBytes++ // space separator
}
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
batches = append(batches, current)
current = nil
currentBytes = 0
}
current = append(current, d)
if currentBytes > 0 {
currentBytes += 1 + domainLen
} else {
currentBytes = domainLen
}
}
if len(current) > 0 {
batches = append(batches, current)
}
return batches
}
// removeKeysContaining removes all created keys that contain the given substring.
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
var toRemove []string
for key := range s.createdKeys {
if strings.Contains(key, suffix) {
toRemove = append(toRemove, key)
}
}
var multiErr *multierror.Error
for _, key := range toRemove {
if err := s.removeKeyFromSystemConfig(key); err != nil {
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
}
}
return nberrors.FormatErrorOrNil(multiErr)
}
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
batches := splitDomainsIntoBatches(domains)
for i, batch := range batches {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
domainsStr := strings.Join(batch, " ")
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
return fmt.Errorf("add dns state for batch %d: %w", i, err)
}
s.createdKeys[key] = struct{}{}
}
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
return nil return nil
} }
@@ -345,7 +458,6 @@ func (s *systemConfigurator) flushDNSCache() error {
if out, err := cmd.CombinedOutput(); err != nil { if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out) return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
} }
log.Info("flushed DNS cache") log.Info("flushed DNS cache")
return nil return nil
} }

View File

@@ -3,7 +3,10 @@
package dns package dns
import ( import (
"bufio"
"bytes"
"context" "context"
"fmt"
"net/netip" "net/netip"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@@ -49,17 +52,22 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
require.NoError(t, sm.PersistState(context.Background())) require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
// Collect all created keys for cleanup verification
createdKeys := make([]string, 0, len(configurator.createdKeys))
for key := range configurator.createdKeys {
createdKeys = append(createdKeys, key)
}
defer func() { defer func() {
for _, key := range []string{searchKey, matchKey, localKey} { for _, key := range createdKeys {
_ = removeTestDNSKey(key) _ = removeTestDNSKey(key)
} }
_ = removeTestDNSKey(localKey)
}() }()
for _, key := range []string{searchKey, matchKey, localKey} { for _, key := range createdKeys {
exists, err := checkDNSKeyExists(key) exists, err := checkDNSKeyExists(key)
require.NoError(t, err) require.NoError(t, err)
if exists { if exists {
@@ -83,13 +91,223 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
err = shutdownState.Cleanup() err = shutdownState.Cleanup()
require.NoError(t, err) require.NoError(t, err)
for _, key := range []string{searchKey, matchKey, localKey} { for _, key := range createdKeys {
exists, err := checkDNSKeyExists(key) exists, err := checkDNSKeyExists(key)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key) assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
} }
} }
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
func generateShortDomains(count int) []string {
domains := make([]string, 0, count)
for i := range count {
label := ""
n := i
for {
label = string(rune('a'+n%26)) + label
n = n/26 - 1
if n < 0 {
break
}
}
domains = append(domains, label+".com")
}
return domains
}
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
func generateLongDomains(count int) []string {
domains := make([]string, 0, count)
for i := range count {
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
}
return domains
}
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
func readDomainsFromKey(t *testing.T, key string) []string {
t.Helper()
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
out, err := cmd.Output()
require.NoError(t, err, "scutil show should succeed")
var domains []string
inArray := false
scanner := bufio.NewScanner(bytes.NewReader(out))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
inArray = true
continue
}
if inArray {
if line == "}" {
break
}
// lines look like: "0 : a.com"
parts := strings.SplitN(line, " : ", 2)
if len(parts) == 2 {
domains = append(domains, parts[1])
}
}
}
require.NoError(t, scanner.Err())
return domains
}
func TestSplitDomainsIntoBatches(t *testing.T) {
tests := []struct {
name string
domains []string
expectedCount int
checkAllPresent bool
}{
{
name: "empty",
domains: nil,
expectedCount: 0,
},
{
name: "under_limit",
domains: generateShortDomains(10),
expectedCount: 1,
checkAllPresent: true,
},
{
name: "at_element_limit",
domains: generateShortDomains(50),
expectedCount: 1,
checkAllPresent: true,
},
{
name: "over_element_limit",
domains: generateShortDomains(51),
expectedCount: 2,
checkAllPresent: true,
},
{
name: "triple_element_limit",
domains: generateShortDomains(150),
expectedCount: 3,
checkAllPresent: true,
},
{
name: "long_domains_hit_byte_limit",
domains: generateLongDomains(50),
checkAllPresent: true,
},
{
name: "500_short_domains",
domains: generateShortDomains(500),
expectedCount: 10,
checkAllPresent: true,
},
{
name: "500_long_domains",
domains: generateLongDomains(500),
checkAllPresent: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
batches := splitDomainsIntoBatches(tc.domains)
if tc.expectedCount > 0 {
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
}
// Verify each batch respects limits
for i, batch := range batches {
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
"batch %d exceeds element limit", i)
totalBytes := 0
for j, d := range batch {
if j > 0 {
totalBytes++
}
totalBytes += len(d)
}
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
}
if tc.checkAllPresent {
var all []string
for _, batch := range batches {
all = append(all, batch...)
}
assert.Equal(t, tc.domains, all, "all domains should be present in order")
}
})
}
}
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
// and verifies all domains are readable across multiple scutil keys.
func TestMatchDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
testCases := []struct {
name string
count int
generator func(int) []string
}{
{"short_10", 10, generateShortDomains},
{"short_50", 50, generateShortDomains},
{"short_100", 100, generateShortDomains},
{"short_200", 200, generateShortDomains},
{"short_500", 500, generateShortDomains},
{"long_10", 10, generateLongDomains},
{"long_50", 50, generateLongDomains},
{"long_100", 100, generateLongDomains},
{"long_200", 200, generateLongDomains},
{"long_500", 500, generateLongDomains},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
defer func() {
for key := range configurator.createdKeys {
_ = removeTestDNSKey(key)
}
}()
domains := tc.generator(tc.count)
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
require.NoError(t, err)
batches := splitDomainsIntoBatches(domains)
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
// Read back all domains from all batched keys
var got []string
for i := range batches {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
require.True(t, exists, "key %s should exist", key)
got = append(got, readDomainsFromKey(t, key)...)
}
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
assert.Equal(t, tc.count, len(got), "all domains should be readable")
assert.Equal(t, domains, got, "domains should match in order")
})
}
}
func checkDNSKeyExists(key string) (bool, error) { func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath) cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
@@ -109,3 +327,169 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput() _, err := cmd.CombinedOutput()
return err return err
} }
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
cleanup := func() {
_ = sm.Stop(context.Background())
for key := range configurator.createdKeys {
_ = removeTestDNSKey(key)
}
// Also clean up old-format keys and local key in case they exist
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -42,6 +42,8 @@ const (
dnsPolicyConfigConfigOptionsKey = "ConfigOptions" dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8 dnsPolicyConfigConfigOptionsValue = 0x8
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
@@ -198,10 +200,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
// Update count even on error to ensure cleanup covers partially created rules
r.nrptEntryCount = count
if err != nil { if err != nil {
return fmt.Errorf("add dns match policy: %w", err) return fmt.Errorf("add dns match policy: %w", err)
} }
r.nrptEntryCount = count
} else { } else {
r.nrptEntryCount = 0 r.nrptEntryCount = 0
} }
@@ -239,23 +242,33 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
singleDomain := []string{domain} // We need to batch domains into chunks and create one NRPT rule per batch.
ruleIndex := 0
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
end := i + nrptMaxDomainsPerRule
if end > len(domains) {
end = len(domains)
}
batchDomains := domains[i:end]
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
} }
// Increment immediately so the caller's cleanup path knows about this rule
ruleIndex++
if r.gpo { if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err) return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
} }
} }
log.Debugf("added NRPT entry for domain: %s", domain) log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
} }
if r.gpo { if r.gpo {
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
} }
} }
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains) log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
return len(domains), nil return ruleIndex, nil
} }
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {

View File

@@ -12,6 +12,7 @@ import (
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up // TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes. // when the number of match domains decreases between configuration changes.
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping registry integration test in short mode") t.Skip("skipping registry integration test in short mode")
@@ -37,51 +38,60 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
gpo: false, gpo: false,
} }
config5 := HostDNSConfig{ // Create 125 domains which will result in 3 NRPT rules (50+50+25)
ServerIP: testIP, domains125 := make([]DomainConfig, 125)
Domains: []DomainConfig{ for i := 0; i < 125; i++ {
{Domain: "domain1.com", MatchOnly: true}, domains125[i] = DomainConfig{
{Domain: "domain2.com", MatchOnly: true}, Domain: fmt.Sprintf("domain%d.com", i+1),
{Domain: "domain3.com", MatchOnly: true}, MatchOnly: true,
{Domain: "domain4.com", MatchOnly: true}, }
{Domain: "domain5.com", MatchOnly: true},
},
} }
err = cfg.applyDNSConfig(config5, nil) config125 := HostDNSConfig{
ServerIP: testIP,
Domains: domains125,
}
err = cfg.applyDNSConfig(config125, nil)
require.NoError(t, err) require.NoError(t, err)
// Verify all 5 entries exist // Verify 3 NRPT rules exist
for i := 0; i < 5; i++ { assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
for i := 0; i < 3; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err) require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i) assert.True(t, exists, "NRPT rule %d should exist after first config", i)
} }
config2 := HostDNSConfig{ // Reduce to 75 domains which will result in 2 NRPT rules (50+25)
domains75 := make([]DomainConfig, 75)
for i := 0; i < 75; i++ {
domains75[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config75 := HostDNSConfig{
ServerIP: testIP, ServerIP: testIP,
Domains: []DomainConfig{ Domains: domains75,
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
} }
err = cfg.applyDNSConfig(config2, nil) err = cfg.applyDNSConfig(config75, nil)
require.NoError(t, err) require.NoError(t, err)
// Verify first 2 entries exist // Verify first 2 NRPT rules exist
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err) require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i) assert.True(t, exists, "NRPT rule %d should exist after second config", i)
} }
// Verify entries 2-4 are cleaned up // Verify rule 2 is cleaned up
for i := 2; i < 5; i++ { exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) require.NoError(t, err)
require.NoError(t, err) assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
} }
func registryKeyExists(path string) (bool, error) { func registryKeyExists(path string) (bool, error) {
@@ -97,6 +107,106 @@ func registryKeyExists(path string) (bool, error) {
} }
func cleanupRegistryKeys(*testing.T) { func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10} // Clean up more entries to account for batching tests with many domains
cfg := &registryConfigurator{nrptEntryCount: 20}
_ = cfg.removeDNSMatchPolicies() _ = cfg.removeDNSMatchPolicies()
} }
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
func TestNRPTDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
testCases := []struct {
name string
domainCount int
expectedRuleCount int
}{
{
name: "Less than 50 domains (single rule)",
domainCount: 30,
expectedRuleCount: 1,
},
{
name: "Exactly 50 domains (single rule)",
domainCount: 50,
expectedRuleCount: 1,
},
{
name: "51 domains (two rules)",
domainCount: 51,
expectedRuleCount: 2,
},
{
name: "100 domains (two rules)",
domainCount: 100,
expectedRuleCount: 2,
},
{
name: "125 domains (three rules: 50+50+25)",
domainCount: 125,
expectedRuleCount: 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clean up before each subtest
cleanupRegistryKeys(t)
// Generate domains
domains := make([]DomainConfig, tc.domainCount)
for i := 0; i < tc.domainCount; i++ {
domains[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config := HostDNSConfig{
ServerIP: testIP,
Domains: domains,
}
err := cfg.applyDNSConfig(config, nil)
require.NoError(t, err)
// Verify that exactly expectedRuleCount rules were created
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
// Verify all expected rules exist
for i := 0; i < tc.expectedRuleCount; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist", i)
}
// Verify no extra rules were created
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
require.NoError(t, err)
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
})
}
}

View File

@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
} }
} }
if serverDomains.Flow != "" { // Flow receiver domain is intentionally excluded from caching.
domains = append(domains, serverDomains.Flow) // Cloud providers may rotate the IP behind this domain; a stale cached record
} // causes TLS certificate verification failures on reconnect.
for _, stun := range serverDomains.Stuns { for _, stun := range serverDomains.Stuns {
if stun != "" { if stun != "" {

View File

@@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
} }
assert.Len(t, resolver.GetCachedDomains(), 3) assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing) // Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
partialDomains := dnsconfig.ServerDomains{ partialDomains := dnsconfig.ServerDomains{
Flow: "github.com", Flow: "github.com",
} }
@@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
t.Skipf("Skipping test due to DNS resolution failure: %v", err) t.Skipf("Skipping test due to DNS resolution failure: %v", err)
} }
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
finalDomains := resolver.GetCachedDomains() finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
domainStrings := make([]string, len(finalDomains)) domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains { for i, d := range finalDomains {
@@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
assert.Contains(t, domainStrings, "example.org") assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com") assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com") assert.Contains(t, domainStrings, "cloudflare.com")
assert.Contains(t, domainStrings, "github.com") assert.NotContains(t, domainStrings, "github.com")
} }

View File

@@ -84,3 +84,18 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil return nil
} }
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
}
// EndBatch mock implementation of EndBatch from Server interface
func (m *MockServer) EndBatch() {
// Mock implementation - no-op
}
// CancelBatch mock implementation of CancelBatch from Server interface
func (m *MockServer) CancelBatch() {
// Mock implementation - no-op
}

View File

@@ -6,7 +6,9 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"net/url" "net/url"
"os"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
@@ -27,6 +29,8 @@ import (
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface { type ReadyListener interface {
OnReady() OnReady()
@@ -41,6 +45,9 @@ type IosDnsManager interface {
type Server interface { type Server interface {
RegisterHandler(domains domain.List, handler dns.Handler, priority int) RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int) DeregisterHandler(domains domain.List, priority int)
BeginBatch()
EndBatch()
CancelBatch()
Initialize() error Initialize() error
Stop() Stop()
DnsIP() netip.Addr DnsIP() netip.Addr
@@ -83,6 +90,7 @@ type DefaultServer struct {
currentConfigHash uint64 currentConfigHash uint64
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int extraDomains map[domain.Domain]int
batchMode bool
mgmtCacheResolver *mgmt.Resolver mgmtCacheResolver *mgmt.Resolver
@@ -230,7 +238,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter // convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++ s.extraDomains[toZone(domain)]++
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -259,9 +269,41 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
delete(s.extraDomains, zone) delete(s.extraDomains, zone)
} }
} }
if !s.batchMode {
s.applyHostConfig()
}
}
// BeginBatch starts batch mode for DNS handler registration/deregistration.
// In batch mode, applyHostConfig() is not called after each handler operation,
// allowing multiple handlers to be registered/deregistered efficiently.
// Must be followed by EndBatch() to apply the accumulated changes.
func (s *DefaultServer) BeginBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode enabled")
s.batchMode = true
}
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
func (s *DefaultServer) EndBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode disabled, applying accumulated changes")
s.batchMode = false
s.applyHostConfig() s.applyHostConfig()
} }
// CancelBatch cancels batch mode without applying accumulated changes.
// This is useful when operations fail partway through and you want to
// discard partial state rather than applying it.
func (s *DefaultServer) CancelBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
s.batchMode = false
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) { func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler with priority %d for %v", priority, domains) log.Debugf("deregistering handler with priority %d for %v", priority, domains)
@@ -439,6 +481,17 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability // ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds // and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() { func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
}
if skipProbe {
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
return
}
}
var wg sync.WaitGroup var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap { for _, mux := range s.dnsMuxMap {
wg.Add(1) wg.Add(1)
@@ -508,6 +561,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
} }
// Always apply host config for management updates, regardless of batch mode
s.applyHostConfig() s.applyHostConfig()
s.shutdownWg.Add(1) s.shutdownWg.Add(1)
@@ -615,7 +669,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config) s.registerFallback(config)
} }
// registerFallback registers original nameservers as low-priority fallback handlers // registerFallback registers original nameservers as low-priority fallback handlers.
func (s *DefaultServer) registerFallback(config HostDNSConfig) { func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok { if !ok {
@@ -624,6 +678,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers() originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 { if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return return
} }
@@ -871,6 +926,7 @@ func (s *DefaultServer) upstreamCallbacks(
} }
} }
// Always apply host config when nameserver goes down, regardless of batch mode
s.applyHostConfig() s.applyHostConfig()
go func() { go func() {
@@ -906,6 +962,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority) s.registerHandler([]string{nbdns.RootZone}, handler, priority)
} }
// Always apply host config when nameserver reactivates, regardless of batch mode
s.applyHostConfig() s.applyHostConfig()
s.updateNSState(nsGroup, nil, true) s.updateNSState(nsGroup, nil, true)

View File

@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
t.Errorf("invalid dns server instance: %s", err) t.Errorf("invalid dns server instance: %s", err)
} }
if srvB != srv { mockSrvB, ok := srvB.(*MockServer)
if !ok {
t.Errorf("returned server is not a MockServer")
}
if mockSrvB != srv {
t.Errorf("mismatch dns instances") t.Errorf("mismatch dns instances")
} }
} }

View File

@@ -8,15 +8,21 @@ import (
type MockResponseWriter struct { type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
} }
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil { if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m) return rw.WriteMsgFunc(m)
} }
return nil return nil
} }
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -351,9 +351,13 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return fmt.Errorf("upstream check call error") return fmt.Errorf("upstream check call error")
} }
err := backoff.Retry(operation, exponentialBackOff) err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil { if err != nil {
log.Warn(err) if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return return
} }

View File

@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
} }
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg { func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
if len(query.Question) == 0 { if len(query.Question) == 0 {
return nil return
} }
question := query.Question[0] question := query.Question[0]
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s", qname := strings.ToLower(question.Name)
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name) logger.Tracef("question: domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query) resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype) network := resutil.NetworkForQtype(question.Qtype)
if network == "" { if network == "" {
resp.Rcode = dns.RcodeNotImplemented resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil { f.writeResponse(logger, w, resp, qname, startTime)
logger.Errorf("failed to write DNS response: %v", err) return
}
return nil
} }
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" { if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil { f.writeResponse(logger, w, resp, qname, startTime)
logger.Errorf("failed to write DNS response: %v", err) return
}
return nil
} }
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype) result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil { if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result) f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return nil return
} }
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries) f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...) resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs) f.cache.set(qname, question.Qtype, result.IPs)
return resp f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
type udpResponseWriter struct {
dns.ResponseWriter
query *dns.Msg
}
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
opt := u.query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
maxSize = int(opt.UDPSize())
}
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
return u.ResponseWriter.WriteMsg(resp)
} }
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id), "dns_id": fmt.Sprintf("%04x", query.Id),
}) })
resp := f.handleDNSQuery(logger, w, query) f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
} }
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id), "dns_id": fmt.Sprintf("%04x", query.Id),
}) })
resp := f.handleDNSQuery(logger, w, query) f.handleDNSQuery(logger, w, query, startTime)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
} }
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg, resp *dns.Msg,
domain string, domain string,
result resutil.LookupResult, result resutil.LookupResult,
startTime time.Time,
) { ) {
qType := question.Qtype qType := question.Qtype
qTypeName := dns.TypeToString[qType] qTypeName := dns.TypeToString[qType]
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
// NotFound: cache negative result and respond // NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess { if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil) f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil { f.writeResponse(logger, w, resp, domain, startTime)
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return return
} }
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...) resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil { f.writeResponse(logger, w, resp, domain, startTime)
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return return
} }
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType) verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess { if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil { f.writeResponse(logger, w, resp, domain, startTime)
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return return
} }
} }
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
// No cache or verification failed. Log with or without the server field for more context. // No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" { if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else { } else {
logger.Warnf(errResolveFailed, domain, result.Err) logger.Warnf(errResolveFailed, domain, result.Err)
} }
// Write final failure response. f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
} }
// getMatchingEntries retrieves the resource IDs for a given domain. // getMatchingEntries retrieves the resource IDs for a given domain.

View File

@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
if tt.shouldResolve { if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain") require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockFirewall.AssertExpectations(t) mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t) mockResolver.AssertExpectations(t)
} else { } else {
if resp != nil { require.NotNil(t, resp, "Expected response")
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers") "Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet") mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP") mockResolver.AssertNotCalled(t, "LookupNetIP")
} }
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery) forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
// Verify response // Verify response
resp := mockWriter.GetLastResponse()
if tt.shouldResolve { if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain") require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer) require.NotEmpty(t, resp.Answer)
} else if resp != nil { } else {
require.NotNil(t, resp, "Expected response")
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers") "Unauthorized domain should be refused or have no answers")
} }
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA) query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
// Verify response contains all IPs // Verify response contains all IPs
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp) require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records") require.Len(t, resp.Answer, 3, "Should have 3 answer records")
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
}, },
} }
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
// Check the response written to the writer // Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written") require.NotNil(t, writtenResp, "Expected response to be written")
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{} q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{} w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
require.NotNil(t, resp1) require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1) require.Len(t, resp1.Answer, 1)
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Second query: serve from cache after upstream failure // Second query: serve from cache after upstream failure
q2 := &dns.Msg{} q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg w2 := &test.MockResponseWriter{}
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp, "expected response to be written") resp2 := w2.GetLastResponse()
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) require.NotNil(t, resp2, "expected response to be written")
require.Len(t, writtenResp.Answer, 1) require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
mockResolver.AssertExpectations(t) mockResolver.AssertExpectations(t)
} }
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{} q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA) q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{} w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
require.NotNil(t, resp1) require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1) require.Len(t, resp1.Answer, 1)
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2 := &dns.Msg{} q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA) q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg w2 := &test.MockResponseWriter{}
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp) resp2 := w2.GetLastResponse()
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) require.NotNil(t, resp2)
require.Len(t, writtenResp.Answer, 1) require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
mockResolver.AssertExpectations(t) mockResolver.AssertExpectations(t)
} }
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA) query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp) require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode) assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
query := &dns.Msg{} query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType) query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
var writtenResp *dns.Msg mockWriter := &test.MockResponseWriter{}
mockWriter := &test.MockResponseWriter{ forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "Expected response to be written")
// If a response was returned, it means it should be written (happens in wrapper functions) assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer { if tt.expectNoAnswer {
assert.Empty(t, writtenResp.Answer, "Response should have no answer records") assert.Empty(t, resp.Answer, "Response should have no answer records")
} }
mockResolver.AssertExpectations(t) mockResolver.AssertExpectations(t)
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
query := &dns.Msg{} query := &dns.Msg{}
// Don't set any question // Don't set any question
writeCalled := false mockWriter := &test.MockResponseWriter{}
mockWriter := &test.MockResponseWriter{ forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query") assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
} }

View File

@@ -29,12 +29,14 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -52,13 +54,11 @@ import (
"github.com/netbirdio/netbird/client/internal/updatemanager" "github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec" "github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -74,7 +74,6 @@ import (
const ( const (
PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled" disableAutoUpdate = "disabled"
) )
@@ -207,7 +206,6 @@ type Engine struct {
syncRespMux sync.RWMutex syncRespMux sync.RWMutex
persistSyncResponse bool persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager flowManager nftypes.FlowManager
// auto-update // auto-update
@@ -223,6 +221,8 @@ type Engine struct {
jobExecutor *jobexec.Executor jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup jobExecutorWG sync.WaitGroup
exposeManager *expose.Manager
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -265,7 +265,6 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager, stateManager: stateManager,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(), jobExecutor: jobexec.NewExecutor(),
} }
@@ -418,6 +417,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.cancel() e.cancel()
} }
e.ctx, e.cancel = context.WithCancel(e.clientCtx) e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
if err != nil { if err != nil {
@@ -543,11 +543,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes // monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor() e.wgIfaceMonitor = NewWGIfaceMonitor()
e.shutdownWg.Add(1) e.shutdownWg.Add(1)
wgIfaceName := e.wgInterface.Name()
go func() { go func() {
defer e.shutdownWg.Done() defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err) log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.triggerClientRestart() e.triggerClientRestart()
} else if err != nil { } else if err != nil {
@@ -573,9 +574,11 @@ func (e *Engine) createFirewall() error {
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil { if err != nil {
log.Errorf("failed creating firewall manager: %s", err) return fmt.Errorf("create firewall manager: %w", err)
return nil }
if e.firewall == nil {
return fmt.Errorf("create firewall manager: received nil manager")
} }
if err := e.initFirewall(); err != nil { if err := e.initFirewall(); err != nil {
@@ -797,7 +800,7 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
disabled := autoUpdateSettings.Version == disableAutoUpdate disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled // stop and cleanup if disabled
if e.updateManager != nil && disabled { if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager") log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop() e.updateManager.Stop()
@@ -826,6 +829,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
} }
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -1015,7 +1022,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
state := e.statusRecorder.GetLocalPeerState() state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String() state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String() state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = device.WireGuardModuleIsLoaded() state.KernelInterface = !e.wgInterface.IsUserspaceBind()
state.FQDN = conf.GetFqdn() state.FQDN = conf.GetFqdn()
e.statusRecorder.UpdateLocalPeerState(state) e.statusRecorder.UpdateLocalPeerState(state)
@@ -1531,7 +1538,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
IFaceDiscover: e.mobileDep.IFaceDiscover, IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager, RelayManager: e.relayManager,
SrWatcher: e.srWatcher, SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore,
} }
peerConn, err := peer.NewConn(config, serviceDependencies) peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil { if err != nil {
@@ -1554,8 +1560,10 @@ func (e *Engine) receiveSignalEvents() {
defer e.shutdownWg.Done() defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server // connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
start := time.Now()
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
gotLock := time.Since(start)
// Check context INSIDE lock to ensure atomicity with shutdown // Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil { if e.ctx.Err() != nil {
@@ -1579,6 +1587,8 @@ func (e *Engine) receiveSignalEvents() {
return err return err
} }
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
if msg.Body.Type == sProto.Body_OFFER { if msg.Body.Type == sProto.Body_OFFER {
conn.OnRemoteOffer(*offerAnswer) conn.OnRemoteOffer(*offerAnswer)
} else { } else {
@@ -1812,11 +1822,18 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager return e.routeManager
} }
// GetFirewallManager returns the firewall manager // GetFirewallManager returns the firewall manager.
func (e *Engine) GetFirewallManager() firewallManager.Manager { func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall return e.firewall
} }
// GetExposeManager returns the expose session manager.
func (e *Engine) GetExposeManager() *expose.Manager {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return e.exposeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {
@@ -1916,7 +1933,7 @@ func (e *Engine) triggerClientRestart() {
} }
func (e *Engine) startNetworkMonitor() { func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor { if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
log.Infof("Network monitor is disabled, not starting") log.Infof("Network monitor is disabled, not starting")
return return
} }

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth" sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config" sshconfig "github.com/netbirdio/netbird/client/ssh/config"
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
// updateSSHClientConfig updates the SSH client configuration with peer information // updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error { func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
if netstack.IsEnabled() {
return nil
}
peerInfo := e.extractPeerSSHInfo(remotePeers) peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 { if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update") log.Debug("no SSH-enabled peers found, skipping SSH config update")
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown // cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() { func (e *Engine) cleanupSSHConfig() {
if netstack.IsEnabled() {
return
}
configMgr := sshconfig.New() configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil { if err := configMgr.RemoveSSHClientConfig(); err != nil {

View File

@@ -0,0 +1,95 @@
package expose
import (
"context"
"time"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
)
const renewTimeout = 10 * time.Second
// Response holds the response from exposing a service.
type Response struct {
ServiceName string
ServiceURL string
Domain string
}
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol int
Pin string
Password string
UserGroups []string
}
type ManagementClient interface {
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
RenewExpose(ctx context.Context, domain string) error
StopExpose(ctx context.Context, domain string) error
}
// Manager handles expose session lifecycle via the management client.
type Manager struct {
mgmClient ManagementClient
ctx context.Context
}
// NewManager creates a new expose Manager using the given management client.
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
return &Manager{mgmClient: mgmClient, ctx: ctx}
}
// Expose creates a new expose session via the management server.
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
log.Infof("exposing service on port %d", req.Port)
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
if err != nil {
return nil, err
}
log.Infof("expose session created for %s", resp.Domain)
return fromClientExposeResponse(resp), nil
}
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
defer m.stop(domain)
for {
select {
case <-ctx.Done():
log.Infof("context canceled, stopping keep alive for %s", domain)
return nil
case <-ticker.C:
if err := m.renew(ctx, domain); err != nil {
log.Errorf("renewing expose session for %s: %v", domain, err)
return err
}
}
}
}
// renew extends the TTL of an active expose session.
func (m *Manager) renew(ctx context.Context, domain string) error {
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
defer cancel()
return m.mgmClient.RenewExpose(renewCtx, domain)
}
// stop terminates an active expose session.
func (m *Manager) stop(domain string) {
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
defer cancel()
err := m.mgmClient.StopExpose(stopCtx, domain)
if err != nil {
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
}
}

View File

@@ -0,0 +1,95 @@
package expose
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
func TestManager_Expose_Success(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return &mgm.ExposeResponse{
ServiceName: "my-service",
ServiceURL: "https://my-service.example.com",
Domain: "my-service.example.com",
}, nil
},
}
m := NewManager(context.Background(), mock)
result, err := m.Expose(context.Background(), Request{Port: 8080})
require.NoError(t, err)
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
}
func TestManager_Expose_Error(t *testing.T) {
mock := &mgm.MockClient{
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
return nil, errors.New("permission denied")
},
}
m := NewManager(context.Background(), mock)
_, err := m.Expose(context.Background(), Request{Port: 8080})
require.Error(t, err)
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
}
func TestManager_Renew_Success(t *testing.T) {
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
return nil
},
}
m := NewManager(context.Background(), mock)
err := m.renew(context.Background(), "my-service.example.com")
require.NoError(t, err)
}
func TestManager_Renew_Timeout(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
mock := &mgm.MockClient{
RenewExposeFunc: func(ctx context.Context, domain string) error {
return ctx.Err()
},
}
m := NewManager(ctx, mock)
err := m.renew(ctx, "my-service.example.com")
require.Error(t, err)
}
func TestNewRequest(t *testing.T) {
req := &daemonProto.ExposeServiceRequest{
Port: 8080,
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
Pin: "123456",
Password: "secret",
UserGroups: []string{"group1", "group2"},
Domain: "custom.example.com",
NamePrefix: "my-prefix",
}
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
}

View File

@@ -0,0 +1,39 @@
package expose
import (
daemonProto "github.com/netbirdio/netbird/client/proto"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
Domain: req.Domain,
NamePrefix: req.NamePrefix,
}
}
func toClientExposeRequest(req Request) mgm.ExposeRequest {
return mgm.ExposeRequest{
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: req.Protocol,
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
}
}
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
return &Response{
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
}
}

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg) return NewUDPListener(m.wgIface, peerCfg)
} }
// BindListener is only used on Windows and JS platforms: // BindListener is used on Windows, JS, and netstack platforms:
// - JS: Cannot listen to UDP sockets // - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface. // gateway points to, preventing them from reaching the loopback interface.
// BindListener bypasses this by passing data directly through the bind. // - Netstack: Allows multiple instances on the same host without port conflicts.
if runtime.GOOS != "windows" && runtime.GOOS != "js" { // BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
return NewUDPListener(m.wgIface, peerCfg) return NewUDPListener(m.wgIface, peerCfg)
} }

View File

@@ -22,51 +22,56 @@ func prepareFd() (int, error) {
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
for { for {
select { // Wait until fd is readable or context is cancelled, to avoid a busy-loop
case <-ctx.Done(): // when the routing socket returns EAGAIN (e.g. immediately after wakeup).
return ctx.Err() if err := waitReadable(ctx, fd); err != nil {
default: return err
buf := make([]byte, 2048) }
n, err := unix.Read(fd, buf)
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
continue
}
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
return fmt.Errorf("routing socket closed: %w", err)
}
return fmt.Errorf("read routing socket: %w", err)
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil { if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { log.Debugf("Network monitor: error parsing routing message: %v", err)
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue continue
} }
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type { switch msg.Type {
// handle route changes case unix.RTM_ADD:
case unix.RTM_ADD, syscall.RTM_DELETE: log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
route, err := parseRouteMessage(buf[:n]) return nil
if err != nil { case unix.RTM_DELETE:
log.Debugf("Network monitor: error parsing routing message: %v", err) if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
continue log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
} }
} }
} }
@@ -90,3 +95,33 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
return systemops.MsgToRoute(msg) return systemops.MsgToRoute(msg)
} }
// waitReadable blocks until fd has data to read, or ctx is cancelled.
func waitReadable(ctx context.Context, fd int) error {
var fdset unix.FdSet
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
return fmt.Errorf("fd %d out of range for FdSet", fd)
}
for {
if err := ctx.Err(); err != nil {
return err
}
fdset = unix.FdSet{}
fdset.Set(fd)
// Use a 1-second timeout so we can re-check ctx periodically.
tv := unix.Timeval{Sec: 1}
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
if err != nil {
if errors.Is(err, unix.EINTR) {
continue
}
return fmt.Errorf("select on routing socket: %w", err)
}
if n > 0 {
return nil
}
// timeout — loop back and re-check ctx
}
}

View File

@@ -3,7 +3,6 @@ package peer
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
@@ -25,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
type ServiceDependencies struct { type ServiceDependencies struct {
@@ -34,7 +32,6 @@ type ServiceDependencies struct {
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
RelayManager *relayClient.Manager RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher SrWatcher *guard.SRWatcher
Semaphore *semaphoregroup.SemaphoreGroup
PeerConnDispatcher *dispatcher.ConnectionDispatcher PeerConnDispatcher *dispatcher.ConnectionDispatcher
} }
@@ -111,9 +108,8 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
handshaker *Handshaker handshaker *Handshaker
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup wg sync.WaitGroup
wg sync.WaitGroup
// debug purpose // debug purpose
dumpState *stateDump dumpState *stateDump
@@ -139,7 +135,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
iFaceDiscover: services.IFaceDiscover, iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager, relayManager: services.RelayManager,
srWatcher: services.SrWatcher, srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(), statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(), statusICE: worker.NewAtomicStatus(),
dumpState: dumpState, dumpState: dumpState,
@@ -154,15 +149,10 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used. // be used.
func (conn *Conn) Open(engineCtx context.Context) error { func (conn *Conn) Open(engineCtx context.Context) error {
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
if conn.opened { if conn.opened {
conn.semaphore.Done()
return nil return nil
} }
@@ -173,7 +163,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil { if err != nil {
conn.semaphore.Done()
return err return err
} }
conn.workerICE = workerICE conn.workerICE = workerICE
@@ -207,10 +196,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1) conn.wg.Add(1)
go func() { go func() {
defer conn.wg.Done() defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent) conn.guard.Start(conn.ctx, conn.onGuardEvent)
}() }()
conn.opened = true conn.opened = true
@@ -390,6 +375,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
} }
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy) conn.handleConfigurationFailure(err, wgProxy)
@@ -402,15 +389,13 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep) conn.wgProxyRelay.RedirectAs(ep)
} }
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
func (conn *Conn) onICEStateDisconnected() { func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -430,14 +415,18 @@ func (conn *Conn) onICEStateDisconnected() {
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.Log.Infof("ICE disconnected, set Relay to active connection") conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay() conn.dumpState.SwitchToRelay()
if sessionChanged {
conn.resetEndpoint()
}
// todo consider to move after the ConfigureWGEndpoint
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil { if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
conn.Log.Errorf("failed to switch to relay conn: %v", err) conn.Log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay conn.currentConnPriority = conntype.Relay
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
@@ -499,19 +488,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
wgProxy.Work() controller := isController(conn.config)
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if controller {
wgProxy.Work()
}
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err) conn.Log.Warnf("Failed to close relay connection: %v", err)
} }
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err) conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
if !controller {
conn.enableWgWatcherIfNeeded() wgProxy.Work()
}
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected() conn.statusRelay.SetConnected()
@@ -663,19 +655,6 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
} }
} }
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
maxWait := 300
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
timeout := time.NewTimer(duration)
defer timeout.Stop()
select {
case <-ctx.Done():
case <-timeout.C:
}
}
func (conn *Conn) isRelayed() bool { func (conn *Conn) isRelayed() bool {
switch conn.currentConnPriority { switch conn.currentConnPriority {
case conntype.Relay, conntype.ICETurn: case conntype.Relay, conntype.ICETurn:
@@ -756,6 +735,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
return wgProxy, nil return wgProxy, nil
} }
func (conn *Conn) resetEndpoint() {
if !isController(conn.config) {
return
}
conn.Log.Infof("reset wg endpoint")
conn.wgWatcher.Reset()
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
}
}
func (conn *Conn) isReadyToUpgrade() bool { func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
} }
@@ -861,9 +851,3 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil return remoteRosenpassPubKey != nil
} }
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
func wgConfigWorkaround() {
time.Sleep(100 * time.Millisecond)
}

View File

@@ -15,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
var testDispatcher = dispatcher.NewConnectionDispatcher() var testDispatcher = dispatcher.NewConnectionDispatcher()
@@ -53,7 +52,6 @@ func TestConn_GetKey(t *testing.T) {
sd := ServiceDependencies{ sd := ServiceDependencies{
SrWatcher: swWatcher, SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher, PeerConnDispatcher: testDispatcher,
} }
conn, err := NewConn(connConf, sd) conn, err := NewConn(connConf, sd)
@@ -71,7 +69,6 @@ func TestConn_OnRemoteOffer(t *testing.T) {
sd := ServiceDependencies{ sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"), StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher, SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher, PeerConnDispatcher: testDispatcher,
} }
conn, err := NewConn(connConf, sd) conn, err := NewConn(connConf, sd)
@@ -110,7 +107,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
sd := ServiceDependencies{ sd := ServiceDependencies{
StatusRecorder: NewRecorder("https://mgm"), StatusRecorder: NewRecorder("https://mgm"),
SrWatcher: swWatcher, SrWatcher: swWatcher,
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
PeerConnDispatcher: testDispatcher, PeerConnDispatcher: testDispatcher,
} }
conn, err := NewConn(connConf, sd) conn, err := NewConn(connConf, sd)

View File

@@ -34,28 +34,27 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
} }
} }
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error { func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()
if e.initiator { if e.initiator {
e.log.Debugf("configure up WireGuard as initiatr") e.log.Debugf("configure up WireGuard as initiator")
return e.updateWireGuardPeer(addr, presharedKey) return e.configureAsInitiator(addr, presharedKey)
} }
e.log.Debugf("configure up WireGuard as responder")
return e.configureAsResponder(addr, presharedKey)
}
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock()
defer e.mu.Unlock()
// prevent to run new update while cancel the previous update // prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate() e.waitForCloseTheDelayedUpdate()
var ctx context.Context return e.updateWireGuardPeer(addr, presharedKey)
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
e.log.Debugf("configure up WireGuard and wait for handshake")
return e.updateWireGuardPeer(nil, presharedKey)
} }
func (e *EndpointUpdater) RemoveWgPeer() error { func (e *EndpointUpdater) RemoveWgPeer() error {
@@ -66,6 +65,38 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
} }
func (e *EndpointUpdater) RemoveEndpointAddress() error {
e.mu.Lock()
defer e.mu.Unlock()
e.waitForCloseTheDelayedUpdate()
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
return err
}
return nil
}
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
// prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate()
e.log.Debugf("configure up WireGuard and wait for handshake")
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
e.waitForCloseTheDelayedUpdate()
return err
}
return nil
}
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
if e.cancelFunc == nil { if e.cancelFunc == nil {
return return
@@ -101,3 +132,9 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
presharedKey, presharedKey,
) )
} }
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
func wgConfigWorkaround() {
time.Sleep(100 * time.Millisecond)
}

View File

@@ -2,6 +2,7 @@ package ice
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"time" "time"
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
once sync.Once once sync.Once
} }
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive() iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout() iceDisconnectedTimeout := iceDisconnectedTimeout()
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
return nil, err return nil, err
} }
if agent == nil {
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
}
return &ThreadSafeAgent{Agent: agent}, nil return &ThreadSafeAgent{Agent: agent}, nil
} }
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
// Defensive check to prevent nil pointer dereference
// This can happen during sleep/wake transitions or memory corruption scenarios
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
agent := a.Agent
if agent == nil {
log.Warnf("ICE agent is nil during close, skipping")
return
}
done := make(chan error, 1)
go func() {
done <- agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func GenerateICECredentials() (string, string, error) { func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil { if err != nil {

View File

@@ -32,6 +32,8 @@ type WGWatcher struct {
enabled bool enabled bool
muEnabled sync.RWMutex muEnabled sync.RWMutex
resetCh chan struct{}
} }
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
wgIfaceStater: wgIfaceStater, wgIfaceStater: wgIfaceStater,
peerKey: peerKey, peerKey: peerKey,
stateDump: stateDump, stateDump: stateDump,
resetCh: make(chan struct{}, 1),
} }
} }
@@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool {
return w.enabled return w.enabled
} }
// Reset signals the watcher that the WireGuard peer has been reset and a new
// handshake is expected. This restarts the handshake timeout from scratch.
func (w *WGWatcher) Reset() {
select {
case w.resetCh <- struct{}{}:
default:
}
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) { func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started") w.log.Infof("WireGuard watcher started")
@@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
w.stateDump.WGcheckSuccess() w.stateDump.WGcheckSuccess()
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime) w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-w.resetCh:
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
lastHandshake = time.Time{}
enabledTime = time.Now()
timer.Stop()
timer.Reset(wgHandshakeOvertime)
case <-ctx.Done(): case <-ctx.Done():
w.log.Infof("WireGuard watcher stopped") w.log.Infof("WireGuard watcher stopped")
return return

View File

@@ -52,8 +52,9 @@ type WorkerICE struct {
// increase by one when disconnecting the agent // increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer // with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection // Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID sessionID ICESessionID
muxAgent sync.Mutex remoteSessionChanged bool
muxAgent sync.Mutex
localUfrag string localUfrag string
localPwd string localPwd string
@@ -106,9 +107,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return return
} }
w.log.Debugf("agent already exists, recreate the connection") w.log.Debugf("agent already exists, recreate the connection")
w.remoteSessionChanged = true
w.agentDialerCancel() w.agentDialerCancel()
if err := w.agent.Close(); err != nil { if w.agent != nil {
w.log.Warnf("failed to close ICE agent: %s", err) if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
} }
sessionID, err := NewICESessionID() sessionID, err := NewICESessionID()
@@ -304,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
} }
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) { func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
cancel() cancel()
if err := agent.Close(); err != nil { if err := agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock()
sessionChanged := w.remoteSessionChanged
w.remoteSessionChanged = false
if w.agent == agent { if w.agent == agent {
// consider to remove from here and move to the OnNewOffer // consider to remove from here and move to the OnNewOffer
@@ -323,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
w.agentConnecting = false w.agentConnecting = false
w.remoteSessionID = "" w.remoteSessionID = ""
} }
w.muxAgent.Unlock() return sessionChanged
} }
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -424,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority // notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel) sessionChanged := w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected(sessionChanged)
} }
default: default:
return return

View File

@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
configDir := filepath.Join(DefaultConfigPathDir, username) configDir := filepath.Join(DefaultConfigPathDir, username)
if _, err := os.Stat(configDir); os.IsNotExist(err) { if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0600); err != nil { if err := os.MkdirAll(configDir, 0700); err != nil {
return "", err return "", err
} }
} }
@@ -206,9 +206,15 @@ func getConfigDirForUser(username string) (string, error) {
return configDir, nil return configDir, nil
} }
func fileExists(path string) bool { func fileExists(path string) (bool, error) {
_, err := os.Stat(path) _, err := os.Stat(path)
return !os.IsNotExist(err) if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -252,7 +258,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
} }
if config.AdminURL == nil { if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL) log.Infof("using default Admin URL %s", DefaultAdminURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL) config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil { if err != nil {
return false, err return false, err
@@ -635,7 +641,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
// UpdateConfig update existing configuration according to input configuration and return with the configuration // UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) { func UpdateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath) return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
} }
@@ -644,7 +654,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one // UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {
@@ -657,7 +671,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if isPreSharedKeyHidden(input.PreSharedKey) { if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil input.PreSharedKey = nil
} }
err := util.EnforcePermission(input.ConfigPath) err = util.EnforcePermission(input.ConfigPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
} }
@@ -784,7 +798,12 @@ func ReadConfig(configPath string) (*Config, error) {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func readConfig(configPath string, createIfMissing bool) (*Config, error) { func readConfig(configPath string, createIfMissing bool) (*Config, error) {
if fileExists(configPath) { configExists, err := fileExists(configPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if configExists {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -831,7 +850,11 @@ func DirectWriteOutConfig(path string, config *Config) error {
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes. // DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox). // Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) { func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) { configExists, err := fileExists(input.ConfigPath)
if err != nil {
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
}
if !configExists {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {

View File

@@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
} }
profPath := filepath.Join(configDir, profileName+".json") profPath := filepath.Join(configDir, profileName+".json")
if fileExists(profPath) { profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists return ErrProfileAlreadyExists
} }
@@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName) return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
} }
profPath := filepath.Join(configDir, profileName+".json") profPath := filepath.Join(configDir, profileName+".json")
if !fileExists(profPath) { profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if !profileExists {
return ErrProfileNotFound return ErrProfileNotFound
} }

View File

@@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
} }
stateFile := filepath.Join(configDir, profileName+".state.json") stateFile := filepath.Join(configDir, profileName+".state.json")
if !fileExists(stateFile) { stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
}
if !stateFileExists {
return nil, errors.New("profile state file does not exist") return nil, errors.New("profile state file does not exist")
} }

View File

@@ -263,8 +263,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
case <-closer: case <-closer:
return return
case routerStates := <-subscription.Events(): case routerStates := <-subscription.Events():
peerStateUpdate <- routerStates select {
log.Debugf("triggered route state update for Peer: %s", peerKey) case peerStateUpdate <- routerStates:
log.Debugf("triggered route state update for Peer: %s", peerKey)
case <-ctx.Done():
return
case <-closer:
return
}
} }
} }
} }

View File

@@ -351,6 +351,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
logger.Errorf("failed to update domain prefixes: %v", err) logger.Errorf("failed to update domain prefixes: %v", err)
} }
// Allow time for route changes to be applied before sending
// the DNS response (relevant on iOS where setTunnelNetworkSettings
// is asynchronous).
waitForRouteSettlement(logger)
d.replaceIPsInDNSResponse(r, newPrefixes, logger) d.replaceIPsInDNSResponse(r, newPrefixes, logger)
} }
} }

View File

@@ -0,0 +1,20 @@
//go:build ios
package dnsinterceptor
import (
"time"
log "github.com/sirupsen/logrus"
)
const routeSettleDelay = 500 * time.Millisecond
// waitForRouteSettlement introduces a short delay on iOS to allow
// setTunnelNetworkSettings to apply route changes before the DNS
// response reaches the application. Without this, the first request
// to a newly resolved domain may bypass the tunnel.
func waitForRouteSettlement(logger *log.Entry) {
logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay)
time.Sleep(routeSettleDelay)
}

View File

@@ -0,0 +1,12 @@
//go:build !ios
package dnsinterceptor
import log "github.com/sirupsen/logrus"
func waitForRouteSettlement(_ *log.Entry) {
// No-op on non-iOS platforms: route changes are applied synchronously by
// the kernel, so no settlement delay is needed before the DNS response
// reaches the application. The delay is only required on iOS where
// setTunnelNetworkSettings applies routes asynchronously.
}

View File

@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
} }
func (m *DefaultManager) setupRefCounters(useNoop bool) { func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) { func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
}, },
func(prefix netip.Prefix, _ struct{}) error { func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) return m.sysOps.RemoveVPNRoute(prefix, toInterface())
}, },
) )
@@ -337,6 +346,23 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
} }
var merr *multierror.Error var merr *multierror.Error
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
batchStarted := false
if m.dnsServer != nil {
m.dnsServer.BeginBatch()
batchStarted = true
defer func() {
if merr != nil {
// On error, cancel batch to discard partial DNS state
m.dnsServer.CancelBatch()
} else {
// On success, apply accumulated DNS changes
m.dnsServer.EndBatch()
}
}()
}
for id, handler := range toRemove { for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil { if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err)) merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
@@ -367,6 +393,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
m.activeRoutes[id] = handler m.activeRoutes[id] = handler
} }
_ = batchStarted // Mark as used
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }

View File

@@ -4,16 +4,17 @@ package systemops
import ( import (
"strings" "strings"
"syscall"
"golang.org/x/sys/unix"
) )
// filterRoutesByFlags returns true if the route message should be ignored based on its flags. // filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool { func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 { if routeMessageFlags&unix.RTF_UP == 0 {
return true return true
} }
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 { if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
return true return true
} }
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string { func formatBSDFlags(flags int) string {
var flagStrs []string var flagStrs []string
if flags&syscall.RTF_UP != 0 { if flags&unix.RTF_UP != 0 {
flagStrs = append(flagStrs, "U") flagStrs = append(flagStrs, "U")
} }
if flags&syscall.RTF_GATEWAY != 0 { if flags&unix.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G") flagStrs = append(flagStrs, "G")
} }
if flags&syscall.RTF_HOST != 0 { if flags&unix.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H") flagStrs = append(flagStrs, "H")
} }
if flags&syscall.RTF_REJECT != 0 { if flags&unix.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R") flagStrs = append(flagStrs, "R")
} }
if flags&syscall.RTF_DYNAMIC != 0 { if flags&unix.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D") flagStrs = append(flagStrs, "D")
} }
if flags&syscall.RTF_MODIFIED != 0 { if flags&unix.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M") flagStrs = append(flagStrs, "M")
} }
if flags&syscall.RTF_STATIC != 0 { if flags&unix.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S") flagStrs = append(flagStrs, "S")
} }
if flags&syscall.RTF_LLINFO != 0 { if flags&unix.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L") flagStrs = append(flagStrs, "L")
} }
if flags&syscall.RTF_LOCAL != 0 { if flags&unix.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l") flagStrs = append(flagStrs, "l")
} }
if flags&syscall.RTF_BLACKHOLE != 0 { if flags&unix.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B") flagStrs = append(flagStrs, "B")
} }
if flags&syscall.RTF_CLONING != 0 { if flags&unix.RTF_CLONING != 0 {
flagStrs = append(flagStrs, "C") flagStrs = append(flagStrs, "C")
} }
if flags&syscall.RTF_WASCLONED != 0 { if flags&unix.RTF_WASCLONED != 0 {
flagStrs = append(flagStrs, "W") flagStrs = append(flagStrs, "W")
} }
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 { if len(flagStrs) == 0 {
return "-" return "-"

View File

@@ -4,17 +4,18 @@ package systemops
import ( import (
"strings" "strings"
"syscall"
"golang.org/x/sys/unix"
) )
// filterRoutesByFlags returns true if the route message should be ignored based on its flags. // filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool { func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 { if routeMessageFlags&unix.RTF_UP == 0 {
return true return true
} }
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 // NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 { if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
return true return true
} }
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string { func formatBSDFlags(flags int) string {
var flagStrs []string var flagStrs []string
if flags&syscall.RTF_UP != 0 { if flags&unix.RTF_UP != 0 {
flagStrs = append(flagStrs, "U") flagStrs = append(flagStrs, "U")
} }
if flags&syscall.RTF_GATEWAY != 0 { if flags&unix.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G") flagStrs = append(flagStrs, "G")
} }
if flags&syscall.RTF_HOST != 0 { if flags&unix.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H") flagStrs = append(flagStrs, "H")
} }
if flags&syscall.RTF_REJECT != 0 { if flags&unix.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R") flagStrs = append(flagStrs, "R")
} }
if flags&syscall.RTF_DYNAMIC != 0 { if flags&unix.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D") flagStrs = append(flagStrs, "D")
} }
if flags&syscall.RTF_MODIFIED != 0 { if flags&unix.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M") flagStrs = append(flagStrs, "M")
} }
if flags&syscall.RTF_STATIC != 0 { if flags&unix.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S") flagStrs = append(flagStrs, "S")
} }
if flags&syscall.RTF_LLINFO != 0 { if flags&unix.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L") flagStrs = append(flagStrs, "L")
} }
if flags&syscall.RTF_LOCAL != 0 { if flags&unix.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l") flagStrs = append(flagStrs, "l")
} }
if flags&syscall.RTF_BLACKHOLE != 0 { if flags&unix.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B") flagStrs = append(flagStrs, "B")
} }
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0 // Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 { if len(flagStrs) == 0 {
return "-" return "-"

View File

@@ -0,0 +1,80 @@
package handler
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
)
type Agent interface {
Up(ctx context.Context) error
Down(ctx context.Context) error
Status() (internal.StatusType, error)
}
type SleepHandler struct {
agent Agent
mu sync.Mutex
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
sleepTriggeredDown bool
}
func New(agent Agent) *SleepHandler {
return &SleepHandler{
agent: agent,
}
}
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.sleepTriggeredDown {
log.Info("skipping up because wasn't sleep down")
return nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown = false
log.Info("running up after wake up")
err := s.agent.Up(ctx)
if err != nil {
log.Errorf("running up failed: %v", err)
return err
}
log.Info("running up command executed successfully")
return nil
}
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
status, err := s.agent.Status()
if err != nil {
return err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
return nil
}
log.Info("running down after system started sleeping")
if err = s.agent.Down(ctx); err != nil {
log.Errorf("running down failed: %v", err)
return err
}
s.sleepTriggeredDown = true
log.Info("running down executed successfully")
return nil
}

View File

@@ -0,0 +1,153 @@
package handler
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockAgent struct {
upErr error
downErr error
statusErr error
status internal.StatusType
upCalls int
}
func (m *mockAgent) Up(_ context.Context) error {
m.upCalls++
return m.upErr
}
func (m *mockAgent) Down(_ context.Context) error {
return m.downErr
}
func (m *mockAgent) Status() (internal.StatusType, error) {
return m.status, m.statusErr
}
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
agent := &mockAgent{status: status}
return New(agent), agent
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
h, _ := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
// Even if Up fails, flag should be reset
_ = h.HandleWakeUp(context.Background())
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
}
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
agent.upErr = errors.New("up failed")
err := h.HandleWakeUp(context.Background())
assert.ErrorIs(t, err, agent.upErr)
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
}
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
h, agent := newHandler(internal.StatusIdle)
h.sleepTriggeredDown = true
_ = h.HandleWakeUp(context.Background())
err := h.HandleWakeUp(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.False(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, _ := newHandler(tt.status)
err := h.HandleSleep(context.Background())
require.NoError(t, err)
assert.True(t, h.sleepTriggeredDown)
})
}
}
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
agent := &mockAgent{statusErr: errors.New("status error")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.statusErr)
assert.False(t, h.sleepTriggeredDown)
}
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
h := New(agent)
err := h.HandleSleep(context.Background())
assert.ErrorIs(t, err, agent.downErr)
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
}

View File

@@ -9,6 +9,8 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
) )
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine // WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
return false, errors.New("not supported on mobile platforms") return false, errors.New("not supported on mobile platforms")
} }
if netstack.IsEnabled() {
log.Debugf("Interface monitor: skipped in netstack mode")
return false, nil
}
if ifaceName == "" { if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor") log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name") return false, errors.New("empty interface name")

View File

@@ -2,7 +2,10 @@
package NetBirdSDK package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer" import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
// EnvList is an exported struct to be bound by gomobile // EnvList is an exported struct to be bound by gomobile
type EnvList struct { type EnvList struct {
@@ -32,3 +35,13 @@ func (el *EnvList) AllItems() map[string]string {
func GetEnvKeyNBForceRelay() string { func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay return peer.EnvKeyNBForceRelay
} }
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
func GetEnvKeyNBLazyConn() string {
return lazyconn.EnvEnableLazyConn
}
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
func GetEnvKeyNBInactivityThreshold() string {
return lazyconn.EnvInactivityThreshold
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.36.6 // protoc-gen-go v1.36.6
// protoc v6.32.1 // protoc v6.33.3
// source: daemon.proto // source: daemon.proto
package proto package proto
@@ -88,6 +88,58 @@ func (LogLevel) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{0} return file_daemon_proto_rawDescGZIP(), []int{0}
} }
type ExposeProtocol int32
const (
ExposeProtocol_EXPOSE_HTTP ExposeProtocol = 0
ExposeProtocol_EXPOSE_HTTPS ExposeProtocol = 1
ExposeProtocol_EXPOSE_TCP ExposeProtocol = 2
ExposeProtocol_EXPOSE_UDP ExposeProtocol = 3
)
// Enum value maps for ExposeProtocol.
var (
ExposeProtocol_name = map[int32]string{
0: "EXPOSE_HTTP",
1: "EXPOSE_HTTPS",
2: "EXPOSE_TCP",
3: "EXPOSE_UDP",
}
ExposeProtocol_value = map[string]int32{
"EXPOSE_HTTP": 0,
"EXPOSE_HTTPS": 1,
"EXPOSE_TCP": 2,
"EXPOSE_UDP": 3,
}
)
func (x ExposeProtocol) Enum() *ExposeProtocol {
p := new(ExposeProtocol)
*p = x
return p
}
func (x ExposeProtocol) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (ExposeProtocol) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[1].Descriptor()
}
func (ExposeProtocol) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[1]
}
func (x ExposeProtocol) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use ExposeProtocol.Descriptor instead.
func (ExposeProtocol) EnumDescriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{1}
}
// avoid collision with loglevel enum // avoid collision with loglevel enum
type OSLifecycleRequest_CycleType int32 type OSLifecycleRequest_CycleType int32
@@ -122,11 +174,11 @@ func (x OSLifecycleRequest_CycleType) String() string {
} }
func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor { func (OSLifecycleRequest_CycleType) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[1].Descriptor() return file_daemon_proto_enumTypes[2].Descriptor()
} }
func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType { func (OSLifecycleRequest_CycleType) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[1] return &file_daemon_proto_enumTypes[2]
} }
func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber { func (x OSLifecycleRequest_CycleType) Number() protoreflect.EnumNumber {
@@ -174,11 +226,11 @@ func (x SystemEvent_Severity) String() string {
} }
func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor { func (SystemEvent_Severity) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[2].Descriptor() return file_daemon_proto_enumTypes[3].Descriptor()
} }
func (SystemEvent_Severity) Type() protoreflect.EnumType { func (SystemEvent_Severity) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[2] return &file_daemon_proto_enumTypes[3]
} }
func (x SystemEvent_Severity) Number() protoreflect.EnumNumber { func (x SystemEvent_Severity) Number() protoreflect.EnumNumber {
@@ -229,11 +281,11 @@ func (x SystemEvent_Category) String() string {
} }
func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor { func (SystemEvent_Category) Descriptor() protoreflect.EnumDescriptor {
return file_daemon_proto_enumTypes[3].Descriptor() return file_daemon_proto_enumTypes[4].Descriptor()
} }
func (SystemEvent_Category) Type() protoreflect.EnumType { func (SystemEvent_Category) Type() protoreflect.EnumType {
return &file_daemon_proto_enumTypes[3] return &file_daemon_proto_enumTypes[4]
} }
func (x SystemEvent_Category) Number() protoreflect.EnumNumber { func (x SystemEvent_Category) Number() protoreflect.EnumNumber {
@@ -5600,6 +5652,224 @@ func (x *InstallerResultResponse) GetErrorMsg() string {
return "" return ""
} }
type ExposeServiceRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Port uint32 `protobuf:"varint,1,opt,name=port,proto3" json:"port,omitempty"`
Protocol ExposeProtocol `protobuf:"varint,2,opt,name=protocol,proto3,enum=daemon.ExposeProtocol" json:"protocol,omitempty"`
Pin string `protobuf:"bytes,3,opt,name=pin,proto3" json:"pin,omitempty"`
Password string `protobuf:"bytes,4,opt,name=password,proto3" json:"password,omitempty"`
UserGroups []string `protobuf:"bytes,5,rep,name=user_groups,json=userGroups,proto3" json:"user_groups,omitempty"`
Domain string `protobuf:"bytes,6,opt,name=domain,proto3" json:"domain,omitempty"`
NamePrefix string `protobuf:"bytes,7,opt,name=name_prefix,json=namePrefix,proto3" json:"name_prefix,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceRequest) Reset() {
*x = ExposeServiceRequest{}
mi := &file_daemon_proto_msgTypes[85]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceRequest) ProtoMessage() {}
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[85]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{85}
}
func (x *ExposeServiceRequest) GetPort() uint32 {
if x != nil {
return x.Port
}
return 0
}
func (x *ExposeServiceRequest) GetProtocol() ExposeProtocol {
if x != nil {
return x.Protocol
}
return ExposeProtocol_EXPOSE_HTTP
}
func (x *ExposeServiceRequest) GetPin() string {
if x != nil {
return x.Pin
}
return ""
}
func (x *ExposeServiceRequest) GetPassword() string {
if x != nil {
return x.Password
}
return ""
}
func (x *ExposeServiceRequest) GetUserGroups() []string {
if x != nil {
return x.UserGroups
}
return nil
}
func (x *ExposeServiceRequest) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *ExposeServiceRequest) GetNamePrefix() string {
if x != nil {
return x.NamePrefix
}
return ""
}
type ExposeServiceEvent struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Event:
//
// *ExposeServiceEvent_Ready
Event isExposeServiceEvent_Event `protobuf_oneof:"event"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceEvent) Reset() {
*x = ExposeServiceEvent{}
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceEvent) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceEvent) ProtoMessage() {}
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{86}
}
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
if x != nil {
return x.Event
}
return nil
}
func (x *ExposeServiceEvent) GetReady() *ExposeServiceReady {
if x != nil {
if x, ok := x.Event.(*ExposeServiceEvent_Ready); ok {
return x.Ready
}
}
return nil
}
type isExposeServiceEvent_Event interface {
isExposeServiceEvent_Event()
}
type ExposeServiceEvent_Ready struct {
Ready *ExposeServiceReady `protobuf:"bytes,1,opt,name=ready,proto3,oneof"`
}
func (*ExposeServiceEvent_Ready) isExposeServiceEvent_Event() {}
type ExposeServiceReady struct {
state protoimpl.MessageState `protogen:"open.v1"`
ServiceName string `protobuf:"bytes,1,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"`
ServiceUrl string `protobuf:"bytes,2,opt,name=service_url,json=serviceUrl,proto3" json:"service_url,omitempty"`
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ExposeServiceReady) Reset() {
*x = ExposeServiceReady{}
mi := &file_daemon_proto_msgTypes[87]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ExposeServiceReady) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ExposeServiceReady) ProtoMessage() {}
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[87]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{87}
}
func (x *ExposeServiceReady) GetServiceName() string {
if x != nil {
return x.ServiceName
}
return ""
}
func (x *ExposeServiceReady) GetServiceUrl() string {
if x != nil {
return x.ServiceUrl
}
return ""
}
func (x *ExposeServiceReady) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
type PortInfo_Range struct { type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5610,7 +5880,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() { func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{} *x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[86] mi := &file_daemon_proto_msgTypes[89]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@@ -5622,7 +5892,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {} func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[86] mi := &file_daemon_proto_msgTypes[89]
if x != nil { if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@@ -6149,7 +6419,25 @@ const file_daemon_proto_rawDesc = "" +
"\x16InstallerResultRequest\"O\n" + "\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" + "\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" + "\berrorMsg\x18\x02 \x01(\tR\berrorMsg\"\xe6\x01\n" +
"\x14ExposeServiceRequest\x12\x12\n" +
"\x04port\x18\x01 \x01(\rR\x04port\x122\n" +
"\bprotocol\x18\x02 \x01(\x0e2\x16.daemon.ExposeProtocolR\bprotocol\x12\x10\n" +
"\x03pin\x18\x03 \x01(\tR\x03pin\x12\x1a\n" +
"\bpassword\x18\x04 \x01(\tR\bpassword\x12\x1f\n" +
"\vuser_groups\x18\x05 \x03(\tR\n" +
"userGroups\x12\x16\n" +
"\x06domain\x18\x06 \x01(\tR\x06domain\x12\x1f\n" +
"\vname_prefix\x18\a \x01(\tR\n" +
"namePrefix\"Q\n" +
"\x12ExposeServiceEvent\x122\n" +
"\x05ready\x18\x01 \x01(\v2\x1a.daemon.ExposeServiceReadyH\x00R\x05readyB\a\n" +
"\x05event\"p\n" +
"\x12ExposeServiceReady\x12!\n" +
"\fservice_name\x18\x01 \x01(\tR\vserviceName\x12\x1f\n" +
"\vservice_url\x18\x02 \x01(\tR\n" +
"serviceUrl\x12\x16\n" +
"\x06domain\x18\x03 \x01(\tR\x06domain*b\n" +
"\bLogLevel\x12\v\n" + "\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" + "\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" +
@@ -6158,7 +6446,14 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" + "\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" + "\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xdd\x14\n" + "\x05TRACE\x10\a*S\n" +
"\x0eExposeProtocol\x12\x0f\n" +
"\vEXPOSE_HTTP\x10\x00\x12\x10\n" +
"\fEXPOSE_HTTPS\x10\x01\x12\x0e\n" +
"\n" +
"EXPOSE_TCP\x10\x02\x12\x0e\n" +
"\n" +
"EXPOSE_UDP\x10\x032\xac\x15\n" +
"\rDaemonService\x126\n" + "\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6197,7 +6492,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" + "\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" + "\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" + "\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3" "\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00\x12M\n" +
"\rExposeService\x12\x1c.daemon.ExposeServiceRequest\x1a\x1a.daemon.ExposeServiceEvent\"\x000\x01B\bZ\x06/protob\x06proto3"
var ( var (
file_daemon_proto_rawDescOnce sync.Once file_daemon_proto_rawDescOnce sync.Once
@@ -6211,214 +6507,222 @@ func file_daemon_proto_rawDescGZIP() []byte {
return file_daemon_proto_rawDescData return file_daemon_proto_rawDescData
} }
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 5)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 91)
var file_daemon_proto_goTypes = []any{ var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel (LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType (ExposeProtocol)(0), // 1: daemon.ExposeProtocol
(SystemEvent_Severity)(0), // 2: daemon.SystemEvent.Severity (OSLifecycleRequest_CycleType)(0), // 2: daemon.OSLifecycleRequest.CycleType
(SystemEvent_Category)(0), // 3: daemon.SystemEvent.Category (SystemEvent_Severity)(0), // 3: daemon.SystemEvent.Severity
(*EmptyRequest)(nil), // 4: daemon.EmptyRequest (SystemEvent_Category)(0), // 4: daemon.SystemEvent.Category
(*OSLifecycleRequest)(nil), // 5: daemon.OSLifecycleRequest (*EmptyRequest)(nil), // 5: daemon.EmptyRequest
(*OSLifecycleResponse)(nil), // 6: daemon.OSLifecycleResponse (*OSLifecycleRequest)(nil), // 6: daemon.OSLifecycleRequest
(*LoginRequest)(nil), // 7: daemon.LoginRequest (*OSLifecycleResponse)(nil), // 7: daemon.OSLifecycleResponse
(*LoginResponse)(nil), // 8: daemon.LoginResponse (*LoginRequest)(nil), // 8: daemon.LoginRequest
(*WaitSSOLoginRequest)(nil), // 9: daemon.WaitSSOLoginRequest (*LoginResponse)(nil), // 9: daemon.LoginResponse
(*WaitSSOLoginResponse)(nil), // 10: daemon.WaitSSOLoginResponse (*WaitSSOLoginRequest)(nil), // 10: daemon.WaitSSOLoginRequest
(*UpRequest)(nil), // 11: daemon.UpRequest (*WaitSSOLoginResponse)(nil), // 11: daemon.WaitSSOLoginResponse
(*UpResponse)(nil), // 12: daemon.UpResponse (*UpRequest)(nil), // 12: daemon.UpRequest
(*StatusRequest)(nil), // 13: daemon.StatusRequest (*UpResponse)(nil), // 13: daemon.UpResponse
(*StatusResponse)(nil), // 14: daemon.StatusResponse (*StatusRequest)(nil), // 14: daemon.StatusRequest
(*DownRequest)(nil), // 15: daemon.DownRequest (*StatusResponse)(nil), // 15: daemon.StatusResponse
(*DownResponse)(nil), // 16: daemon.DownResponse (*DownRequest)(nil), // 16: daemon.DownRequest
(*GetConfigRequest)(nil), // 17: daemon.GetConfigRequest (*DownResponse)(nil), // 17: daemon.DownResponse
(*GetConfigResponse)(nil), // 18: daemon.GetConfigResponse (*GetConfigRequest)(nil), // 18: daemon.GetConfigRequest
(*PeerState)(nil), // 19: daemon.PeerState (*GetConfigResponse)(nil), // 19: daemon.GetConfigResponse
(*LocalPeerState)(nil), // 20: daemon.LocalPeerState (*PeerState)(nil), // 20: daemon.PeerState
(*SignalState)(nil), // 21: daemon.SignalState (*LocalPeerState)(nil), // 21: daemon.LocalPeerState
(*ManagementState)(nil), // 22: daemon.ManagementState (*SignalState)(nil), // 22: daemon.SignalState
(*RelayState)(nil), // 23: daemon.RelayState (*ManagementState)(nil), // 23: daemon.ManagementState
(*NSGroupState)(nil), // 24: daemon.NSGroupState (*RelayState)(nil), // 24: daemon.RelayState
(*SSHSessionInfo)(nil), // 25: daemon.SSHSessionInfo (*NSGroupState)(nil), // 25: daemon.NSGroupState
(*SSHServerState)(nil), // 26: daemon.SSHServerState (*SSHSessionInfo)(nil), // 26: daemon.SSHSessionInfo
(*FullStatus)(nil), // 27: daemon.FullStatus (*SSHServerState)(nil), // 27: daemon.SSHServerState
(*ListNetworksRequest)(nil), // 28: daemon.ListNetworksRequest (*FullStatus)(nil), // 28: daemon.FullStatus
(*ListNetworksResponse)(nil), // 29: daemon.ListNetworksResponse (*ListNetworksRequest)(nil), // 29: daemon.ListNetworksRequest
(*SelectNetworksRequest)(nil), // 30: daemon.SelectNetworksRequest (*ListNetworksResponse)(nil), // 30: daemon.ListNetworksResponse
(*SelectNetworksResponse)(nil), // 31: daemon.SelectNetworksResponse (*SelectNetworksRequest)(nil), // 31: daemon.SelectNetworksRequest
(*IPList)(nil), // 32: daemon.IPList (*SelectNetworksResponse)(nil), // 32: daemon.SelectNetworksResponse
(*Network)(nil), // 33: daemon.Network (*IPList)(nil), // 33: daemon.IPList
(*PortInfo)(nil), // 34: daemon.PortInfo (*Network)(nil), // 34: daemon.Network
(*ForwardingRule)(nil), // 35: daemon.ForwardingRule (*PortInfo)(nil), // 35: daemon.PortInfo
(*ForwardingRulesResponse)(nil), // 36: daemon.ForwardingRulesResponse (*ForwardingRule)(nil), // 36: daemon.ForwardingRule
(*DebugBundleRequest)(nil), // 37: daemon.DebugBundleRequest (*ForwardingRulesResponse)(nil), // 37: daemon.ForwardingRulesResponse
(*DebugBundleResponse)(nil), // 38: daemon.DebugBundleResponse (*DebugBundleRequest)(nil), // 38: daemon.DebugBundleRequest
(*GetLogLevelRequest)(nil), // 39: daemon.GetLogLevelRequest (*DebugBundleResponse)(nil), // 39: daemon.DebugBundleResponse
(*GetLogLevelResponse)(nil), // 40: daemon.GetLogLevelResponse (*GetLogLevelRequest)(nil), // 40: daemon.GetLogLevelRequest
(*SetLogLevelRequest)(nil), // 41: daemon.SetLogLevelRequest (*GetLogLevelResponse)(nil), // 41: daemon.GetLogLevelResponse
(*SetLogLevelResponse)(nil), // 42: daemon.SetLogLevelResponse (*SetLogLevelRequest)(nil), // 42: daemon.SetLogLevelRequest
(*State)(nil), // 43: daemon.State (*SetLogLevelResponse)(nil), // 43: daemon.SetLogLevelResponse
(*ListStatesRequest)(nil), // 44: daemon.ListStatesRequest (*State)(nil), // 44: daemon.State
(*ListStatesResponse)(nil), // 45: daemon.ListStatesResponse (*ListStatesRequest)(nil), // 45: daemon.ListStatesRequest
(*CleanStateRequest)(nil), // 46: daemon.CleanStateRequest (*ListStatesResponse)(nil), // 46: daemon.ListStatesResponse
(*CleanStateResponse)(nil), // 47: daemon.CleanStateResponse (*CleanStateRequest)(nil), // 47: daemon.CleanStateRequest
(*DeleteStateRequest)(nil), // 48: daemon.DeleteStateRequest (*CleanStateResponse)(nil), // 48: daemon.CleanStateResponse
(*DeleteStateResponse)(nil), // 49: daemon.DeleteStateResponse (*DeleteStateRequest)(nil), // 49: daemon.DeleteStateRequest
(*SetSyncResponsePersistenceRequest)(nil), // 50: daemon.SetSyncResponsePersistenceRequest (*DeleteStateResponse)(nil), // 50: daemon.DeleteStateResponse
(*SetSyncResponsePersistenceResponse)(nil), // 51: daemon.SetSyncResponsePersistenceResponse (*SetSyncResponsePersistenceRequest)(nil), // 51: daemon.SetSyncResponsePersistenceRequest
(*TCPFlags)(nil), // 52: daemon.TCPFlags (*SetSyncResponsePersistenceResponse)(nil), // 52: daemon.SetSyncResponsePersistenceResponse
(*TracePacketRequest)(nil), // 53: daemon.TracePacketRequest (*TCPFlags)(nil), // 53: daemon.TCPFlags
(*TraceStage)(nil), // 54: daemon.TraceStage (*TracePacketRequest)(nil), // 54: daemon.TracePacketRequest
(*TracePacketResponse)(nil), // 55: daemon.TracePacketResponse (*TraceStage)(nil), // 55: daemon.TraceStage
(*SubscribeRequest)(nil), // 56: daemon.SubscribeRequest (*TracePacketResponse)(nil), // 56: daemon.TracePacketResponse
(*SystemEvent)(nil), // 57: daemon.SystemEvent (*SubscribeRequest)(nil), // 57: daemon.SubscribeRequest
(*GetEventsRequest)(nil), // 58: daemon.GetEventsRequest (*SystemEvent)(nil), // 58: daemon.SystemEvent
(*GetEventsResponse)(nil), // 59: daemon.GetEventsResponse (*GetEventsRequest)(nil), // 59: daemon.GetEventsRequest
(*SwitchProfileRequest)(nil), // 60: daemon.SwitchProfileRequest (*GetEventsResponse)(nil), // 60: daemon.GetEventsResponse
(*SwitchProfileResponse)(nil), // 61: daemon.SwitchProfileResponse (*SwitchProfileRequest)(nil), // 61: daemon.SwitchProfileRequest
(*SetConfigRequest)(nil), // 62: daemon.SetConfigRequest (*SwitchProfileResponse)(nil), // 62: daemon.SwitchProfileResponse
(*SetConfigResponse)(nil), // 63: daemon.SetConfigResponse (*SetConfigRequest)(nil), // 63: daemon.SetConfigRequest
(*AddProfileRequest)(nil), // 64: daemon.AddProfileRequest (*SetConfigResponse)(nil), // 64: daemon.SetConfigResponse
(*AddProfileResponse)(nil), // 65: daemon.AddProfileResponse (*AddProfileRequest)(nil), // 65: daemon.AddProfileRequest
(*RemoveProfileRequest)(nil), // 66: daemon.RemoveProfileRequest (*AddProfileResponse)(nil), // 66: daemon.AddProfileResponse
(*RemoveProfileResponse)(nil), // 67: daemon.RemoveProfileResponse (*RemoveProfileRequest)(nil), // 67: daemon.RemoveProfileRequest
(*ListProfilesRequest)(nil), // 68: daemon.ListProfilesRequest (*RemoveProfileResponse)(nil), // 68: daemon.RemoveProfileResponse
(*ListProfilesResponse)(nil), // 69: daemon.ListProfilesResponse (*ListProfilesRequest)(nil), // 69: daemon.ListProfilesRequest
(*Profile)(nil), // 70: daemon.Profile (*ListProfilesResponse)(nil), // 70: daemon.ListProfilesResponse
(*GetActiveProfileRequest)(nil), // 71: daemon.GetActiveProfileRequest (*Profile)(nil), // 71: daemon.Profile
(*GetActiveProfileResponse)(nil), // 72: daemon.GetActiveProfileResponse (*GetActiveProfileRequest)(nil), // 72: daemon.GetActiveProfileRequest
(*LogoutRequest)(nil), // 73: daemon.LogoutRequest (*GetActiveProfileResponse)(nil), // 73: daemon.GetActiveProfileResponse
(*LogoutResponse)(nil), // 74: daemon.LogoutResponse (*LogoutRequest)(nil), // 74: daemon.LogoutRequest
(*GetFeaturesRequest)(nil), // 75: daemon.GetFeaturesRequest (*LogoutResponse)(nil), // 75: daemon.LogoutResponse
(*GetFeaturesResponse)(nil), // 76: daemon.GetFeaturesResponse (*GetFeaturesRequest)(nil), // 76: daemon.GetFeaturesRequest
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest (*GetFeaturesResponse)(nil), // 77: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse (*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest (*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse (*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest (*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse (*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest (*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse (*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest (*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse (*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest (*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse (*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
nil, // 89: daemon.Network.ResolvedIPsEntry (*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range (*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
nil, // 91: daemon.SystemEvent.MetadataEntry (*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration (*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp nil, // 93: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 94: daemon.PortInfo.Range
nil, // 95: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 96: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 97: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType 2, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 96, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 28, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 97, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 97, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration 96, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo 26, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 23, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState 22, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
20, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 21, // 9: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
19, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState 20, // 10: daemon.FullStatus.peers:type_name -> daemon.PeerState
23, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState 24, // 11: daemon.FullStatus.relays:type_name -> daemon.RelayState
24, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 25, // 12: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent 58, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState 27, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network 34, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 93, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 94, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 35, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 35, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule 36, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 21: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 0, // 22: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
43, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State 44, // 23: daemon.ListStatesResponse.states:type_name -> daemon.State
52, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags 53, // 24: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 55, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 3, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category 4, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp 97, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 95, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent 58, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 96, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 71, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 1, // 33: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 92, // 34: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
9, // 35: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 33, // 35: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
11, // 36: daemon.DaemonService.Up:input_type -> daemon.UpRequest 8, // 36: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
13, // 37: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 10, // 37: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
15, // 38: daemon.DaemonService.Down:input_type -> daemon.DownRequest 12, // 38: daemon.DaemonService.Up:input_type -> daemon.UpRequest
17, // 39: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest 14, // 39: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
28, // 40: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest 16, // 40: daemon.DaemonService.Down:input_type -> daemon.DownRequest
30, // 41: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest 18, // 41: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
30, // 42: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest 29, // 42: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest
4, // 43: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest 31, // 43: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest
37, // 44: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 31, // 44: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest
39, // 45: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 5, // 45: daemon.DaemonService.ForwardingRules:input_type -> daemon.EmptyRequest
41, // 46: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest 38, // 46: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
44, // 47: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest 40, // 47: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
46, // 48: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest 42, // 48: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
48, // 49: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest 45, // 49: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
50, // 50: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest 47, // 50: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
53, // 51: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest 49, // 51: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
56, // 52: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest 51, // 52: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
58, // 53: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest 54, // 53: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
60, // 54: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest 57, // 54: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
62, // 55: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest 59, // 55: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
64, // 56: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest 61, // 56: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
66, // 57: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest 63, // 57: daemon.DaemonService.SetConfig:input_type -> daemon.SetConfigRequest
68, // 58: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest 65, // 58: daemon.DaemonService.AddProfile:input_type -> daemon.AddProfileRequest
71, // 59: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest 67, // 59: daemon.DaemonService.RemoveProfile:input_type -> daemon.RemoveProfileRequest
73, // 60: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest 69, // 60: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
75, // 61: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest 72, // 61: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest 74, // 62: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest 76, // 63: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest 78, // 64: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest 80, // 65: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest 82, // 66: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest 84, // 67: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest 86, // 68: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 6, // 69: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 88, // 70: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse 90, // 71: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 9, // 72: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse 11, // 73: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 13, // 74: daemon.DaemonService.Up:output_type -> daemon.UpResponse
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse 15, // 75: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse 17, // 76: daemon.DaemonService.Down:output_type -> daemon.DownResponse
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 19, // 77: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse 30, // 78: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 32, // 79: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 32, // 80: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse 37, // 81: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse 39, // 82: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse 41, // 83: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse 43, // 84: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse 46, // 85: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse 48, // 86: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent 50, // 87: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse 52, // 88: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse 56, // 89: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse 58, // 90: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse 60, // 91: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse 62, // 92: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse 64, // 93: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse 66, // 94: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse 68, // 95: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse 70, // 96: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse 73, // 97: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse 75, // 98: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse 77, // 99: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse 79, // 100: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse 81, // 101: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse 83, // 102: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse 85, // 103: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
69, // [69:104] is the sub-list for method output_type 87, // 104: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
34, // [34:69] is the sub-list for method input_type 7, // 105: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
34, // [34:34] is the sub-list for extension type_name 89, // 106: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
34, // [34:34] is the sub-list for extension extendee 91, // 107: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
0, // [0:34] is the sub-list for field type_name 72, // [72:108] is the sub-list for method output_type
36, // [36:72] is the sub-list for method input_type
36, // [36:36] is the sub-list for extension type_name
36, // [36:36] is the sub-list for extension extendee
0, // [0:36] is the sub-list for field type_name
} }
func init() { file_daemon_proto_init() } func init() { file_daemon_proto_init() }
@@ -6439,13 +6743,16 @@ func file_daemon_proto_init() {
file_daemon_proto_msgTypes[58].OneofWrappers = []any{} file_daemon_proto_msgTypes[58].OneofWrappers = []any{}
file_daemon_proto_msgTypes[69].OneofWrappers = []any{} file_daemon_proto_msgTypes[69].OneofWrappers = []any{}
file_daemon_proto_msgTypes[75].OneofWrappers = []any{} file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
(*ExposeServiceEvent_Ready)(nil),
}
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{ File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4, NumEnums: 5,
NumMessages: 88, NumMessages: 91,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -103,6 +103,9 @@ service DaemonService {
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {} rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {} rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
} }
@@ -801,3 +804,32 @@ message InstallerResultResponse {
bool success = 1; bool success = 1;
string errorMsg = 2; string errorMsg = 2;
} }
enum ExposeProtocol {
EXPOSE_HTTP = 0;
EXPOSE_HTTPS = 1;
EXPOSE_TCP = 2;
EXPOSE_UDP = 3;
}
message ExposeServiceRequest {
uint32 port = 1;
ExposeProtocol protocol = 2;
string pin = 3;
string password = 4;
repeated string user_groups = 5;
string domain = 6;
string name_prefix = 7;
}
message ExposeServiceEvent {
oneof event {
ExposeServiceReady ready = 1;
}
}
message ExposeServiceReady {
string service_name = 1;
string service_url = 2;
string domain = 3;
}

View File

@@ -76,6 +76,8 @@ type DaemonServiceClient interface {
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -424,6 +426,38 @@ func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *Instal
return out, nil return out, nil
} }
func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (DaemonService_ExposeServiceClient, error) {
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[1], "/daemon.DaemonService/ExposeService", opts...)
if err != nil {
return nil, err
}
x := &daemonServiceExposeServiceClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type DaemonService_ExposeServiceClient interface {
Recv() (*ExposeServiceEvent, error)
grpc.ClientStream
}
type daemonServiceExposeServiceClient struct {
grpc.ClientStream
}
func (x *daemonServiceExposeServiceClient) Recv() (*ExposeServiceEvent, error) {
m := new(ExposeServiceEvent)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -486,6 +520,8 @@ type DaemonServiceServer interface {
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -598,6 +634,9 @@ func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLi
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) { func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
} }
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, DaemonService_ExposeServiceServer) error {
return status.Errorf(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1244,6 +1283,27 @@ func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Cont
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(ExposeServiceRequest)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(DaemonServiceServer).ExposeService(m, &daemonServiceExposeServiceServer{stream})
}
type DaemonService_ExposeServiceServer interface {
Send(*ExposeServiceEvent) error
grpc.ServerStream
}
type daemonServiceExposeServiceServer struct {
grpc.ServerStream
}
func (x *daemonServiceExposeServiceServer) Send(m *ExposeServiceEvent) error {
return x.ServerStream.SendMsg(m)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -1394,6 +1454,11 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
Handler: _DaemonService_SubscribeEvents_Handler, Handler: _DaemonService_SubscribeEvents_Handler,
ServerStreams: true, ServerStreams: true,
}, },
{
StreamName: "ExposeService",
Handler: _DaemonService_ExposeService_Handler,
ServerStreams: true,
},
}, },
Metadata: "daemon.proto", Metadata: "daemon.proto",
} }

View File

@@ -1,77 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
return s.handleWakeUp(callerCtx)
case proto.OSLifecycleRequest_SLEEP:
return s.handleSleep(callerCtx)
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
if !s.sleepTriggeredDown.Load() {
log.Info("skipping up because wasn't sleep down")
return &proto.OSLifecycleResponse{}, nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown.Store(false)
log.Info("running up after wake up")
_, err := s.Up(callerCtx, &proto.UpRequest{})
if err != nil {
log.Errorf("running up failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
log.Info("running up command executed successfully")
return &proto.OSLifecycleResponse{}, nil
}
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
s.mutex.Lock()
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, nil
}
s.mutex.Unlock()
log.Info("running down after system started sleeping")
_, err = s.Down(callerCtx, &proto.DownRequest{})
if err != nil {
log.Errorf("running down failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
s.sleepTriggeredDown.Store(true)
log.Info("running down executed successfully")
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -1,219 +0,0 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
ctx := internal.CtxInitState(context.Background())
return &Server{
rootCtx: ctx,
statusRecorder: peer.NewRecorder(""),
}
}
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
s := newTestServer()
// sleepTriggeredDown is false by default
assert.False(t, s.sleepTriggeredDown.Load())
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnecting)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnected)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
}
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
s := newTestServer()
// Manually set the flag to simulate prior sleep down
s.sleepTriggeredDown.Store(true)
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
}
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
s := newTestServer()
// First wakeup without prior sleep - should be no-op
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
// Simulate prior sleep
s.sleepTriggeredDown.Store(true)
// First wakeup after sleep - should reset flag
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load())
// Second wakeup - should be no-op
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
s := newTestServer()
resp, err := s.handleWakeUp(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
s := newTestServer()
s.sleepTriggeredDown.Store(true)
// Even if Up fails, flag should be reset
_, _ = s.handleWakeUp(context.Background())
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
resp, err := s.handleSleep(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.handleSleep(ctx)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, s.sleepTriggeredDown.Load())
})
}
}

View File

@@ -21,7 +21,9 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
@@ -85,8 +87,7 @@ type Server struct {
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool updateSettingsDisabled bool
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down sleepHandler *sleephandler.SleepHandler
sleepTriggeredDown atomic.Bool
jwtCache *jwtCache jwtCache *jwtCache
} }
@@ -100,7 +101,7 @@ type oauthAuthFlow struct {
// New server instance constructor. // New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server { func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
return &Server{ s := &Server{
rootCtx: ctx, rootCtx: ctx,
logFile: logFile, logFile: logFile,
persistSyncResponse: true, persistSyncResponse: true,
@@ -110,6 +111,10 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
updateSettingsDisabled: updateSettingsDisabled, updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(), jwtCache: newJWTCache(),
} }
agent := &serverAgent{s}
s.sleepHandler = sleephandler.New(agent)
return s
} }
func (s *Server) Start() error { func (s *Server) Start() error {
@@ -636,8 +641,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return s.waitForUp(callerCtx) return s.waitForUp(callerCtx)
} }
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err) log.Warnf(errRestoreResidualState, err)
} }
@@ -649,10 +652,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
// not in the progress or already successfully established connection. // not in the progress or already successfully established connection.
status, err := state.Status() status, err := state.Status()
if err != nil { if err != nil {
s.mutex.Unlock()
return nil, err return nil, err
} }
if status != internal.StatusIdle { if status != internal.StatusIdle {
s.mutex.Unlock()
return nil, fmt.Errorf("up already in progress: current status %s", status) return nil, fmt.Errorf("up already in progress: current status %s", status)
} }
@@ -669,17 +674,20 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.actCancel = cancel s.actCancel = cancel
if s.config == nil { if s.config == nil {
s.mutex.Unlock()
return nil, fmt.Errorf("config is not defined, please call login command first") return nil, fmt.Errorf("config is not defined, please call login command first")
} }
activeProf, err := s.profileManager.GetActiveProfileState() activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err) log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err) return nil, fmt.Errorf("failed to get active profile state: %w", err)
} }
if msg != nil && msg.ProfileName != nil { if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil { if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err) log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err) return nil, fmt.Errorf("failed to switch profile: %w", err)
} }
@@ -687,6 +695,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
activeProf, err = s.profileManager.GetActiveProfileState() activeProf, err = s.profileManager.GetActiveProfileState()
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile state: %v", err) log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err) return nil, fmt.Errorf("failed to get active profile state: %w", err)
} }
@@ -695,6 +704,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
config, _, err := s.getConfig(activeProf) config, _, err := s.getConfig(activeProf)
if err != nil { if err != nil {
s.mutex.Unlock()
log.Errorf("failed to get active profile config: %v", err) log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err) return nil, fmt.Errorf("failed to get active profile config: %w", err)
} }
@@ -713,6 +723,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
} }
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan) go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
s.mutex.Unlock()
return s.waitForUp(callerCtx) return s.waitForUp(callerCtx)
} }
@@ -838,14 +849,26 @@ func (s *Server) cleanupConnection() error {
if s.actCancel == nil { if s.actCancel == nil {
return ErrServiceNotUp return ErrServiceNotUp
} }
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
// to skip the engine shutdown entirely.
var engine *internal.Engine
if s.connectClient != nil {
engine = s.connectClient.Engine()
}
s.actCancel() s.actCancel()
if s.connectClient == nil { if s.connectClient == nil {
return nil return nil
} }
if err := s.connectClient.Stop(); err != nil { if engine != nil {
return err if err := engine.Stop(); err != nil {
return err
}
} }
s.connectClient = nil s.connectClient = nil
@@ -1312,6 +1335,60 @@ func (s *Server) WaitJWTToken(
}, nil }, nil
} }
// ExposeService exposes a local port via the NetBird reverse proxy.
func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.DaemonService_ExposeServiceServer) error {
s.mutex.Lock()
if !s.clientRunning {
s.mutex.Unlock()
return gstatus.Errorf(codes.FailedPrecondition, "client is not running, run 'netbird up' first")
}
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")
}
ctx := srv.Context()
exposeCtx, exposeCancel := context.WithTimeout(ctx, 30*time.Second)
defer exposeCancel()
mgmReq := expose.NewRequest(req)
result, err := mgr.Expose(exposeCtx, *mgmReq)
if err != nil {
return err
}
if err := srv.Send(&proto.ExposeServiceEvent{
Event: &proto.ExposeServiceEvent_Ready{
Ready: &proto.ExposeServiceReady{
ServiceName: result.ServiceName,
ServiceUrl: result.ServiceURL,
Domain: result.Domain,
},
},
}); err != nil {
return err
}
err = mgr.KeepAlive(ctx, result.Domain)
if err != nil {
return err
}
return nil
}
func isUnixRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false return false
@@ -1548,9 +1625,14 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error { func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
log.Tracef("running client connection") log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate) client := internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) client.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
if err := client.Run(runningChan, s.logFile); err != nil {
return err return err
} }
return nil return nil

View File

@@ -0,0 +1,187 @@
package server
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
return &Server{
rootCtx: context.Background(),
statusRecorder: peer.NewRecorder(""),
}
}
func newDummyConnectClient(ctx context.Context) *internal.ConnectClient {
return internal.NewConnectClient(ctx, nil, nil, false)
}
// TestConnectSetsClientWithMutex validates that connect() sets s.connectClient
// under mutex protection so concurrent readers see a consistent value.
func TestConnectSetsClientWithMutex(t *testing.T) {
s := newTestServer()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Manually simulate what connect() does (without calling Run which panics without full setup)
client := newDummyConnectClient(ctx)
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
// Verify the assignment is visible under mutex
s.mutex.Lock()
assert.Equal(t, client, s.connectClient, "connectClient should be set")
s.mutex.Unlock()
}
// TestConcurrentConnectClientAccess validates that concurrent reads of
// s.connectClient under mutex don't race with a write.
func TestConcurrentConnectClientAccess(t *testing.T) {
s := newTestServer()
ctx := context.Background()
client := newDummyConnectClient(ctx)
var wg sync.WaitGroup
nilCount := 0
setCount := 0
var mu sync.Mutex
// Start readers
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s.mutex.Lock()
c := s.connectClient
s.mutex.Unlock()
mu.Lock()
defer mu.Unlock()
if c == nil {
nilCount++
} else {
setCount++
}
}()
}
// Simulate connect() writing under mutex
time.Sleep(5 * time.Millisecond)
s.mutex.Lock()
s.connectClient = client
s.mutex.Unlock()
wg.Wait()
assert.Equal(t, 50, nilCount+setCount, "all goroutines should complete without panic")
}
// TestCleanupConnection_ClearsConnectClient validates that cleanupConnection
// properly nils out connectClient.
func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
s := newTestServer()
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
s.connectClient = newDummyConnectClient(context.Background())
s.clientRunning = true
err := s.cleanupConnection()
require.NoError(t, err)
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
}
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
// when connectClient is nil.
func TestCleanState_NilConnectClient(t *testing.T) {
s := newTestServer()
s.connectClient = nil
s.profileManager = nil // will cause error if it tries to proceed past the nil check
// Should not panic — the nil check should prevent calling Status() on nil
assert.NotPanics(t, func() {
_, _ = s.CleanState(context.Background(), &proto.CleanStateRequest{All: true})
})
}
// TestDeleteState_NilConnectClient validates that DeleteState doesn't panic
// when connectClient is nil.
func TestDeleteState_NilConnectClient(t *testing.T) {
s := newTestServer()
s.connectClient = nil
s.profileManager = nil
assert.NotPanics(t, func() {
_, _ = s.DeleteState(context.Background(), &proto.DeleteStateRequest{All: true})
})
}
// TestDownThenUp_StaleRunningChan documents the known state issue where
// clientRunningChan from a previous connection is already closed, causing
// waitForUp() to return immediately on reconnect.
func TestDownThenUp_StaleRunningChan(t *testing.T) {
s := newTestServer()
// Simulate state after a successful connection
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
close(s.clientRunningChan) // closed when engine started
s.clientGiveUpChan = make(chan struct{})
s.connectClient = newDummyConnectClient(context.Background())
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
// Simulate Down(): cleanupConnection sets connectClient = nil
s.mutex.Lock()
err := s.cleanupConnection()
s.mutex.Unlock()
require.NoError(t, err)
// After cleanup: connectClient is nil, clientRunning still true
// (goroutine hasn't exited yet)
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits")
s.mutex.Unlock()
// waitForUp() returns immediately due to stale closed clientRunningChan
ctx, ctxCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer ctxCancel()
waitDone := make(chan error, 1)
go func() {
_, err := s.waitForUp(ctx)
waitDone <- err
}()
select {
case err := <-waitDone:
assert.NoError(t, err, "waitForUp returns success on stale channel")
// But connectClient is still nil — this is the stale state issue
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient is nil despite waitForUp success")
s.mutex.Unlock()
case <-time.After(1 * time.Second):
t.Fatal("waitForUp should have returned immediately due to stale closed channel")
}
}
// TestConnectClient_EngineNilOnFreshClient validates that a newly created
// ConnectClient has nil Engine (before Run is called).
func TestConnectClient_EngineNilOnFreshClient(t *testing.T) {
client := newDummyConnectClient(context.Background())
assert.Nil(t, client.Engine(), "engine should be nil on fresh ConnectClient")
}

46
client/server/sleep.go Normal file
View File

@@ -0,0 +1,46 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
type serverAgent struct {
s *Server
}
func (a *serverAgent) Up(ctx context.Context) error {
_, err := a.s.Up(ctx, &proto.UpRequest{})
return err
}
func (a *serverAgent) Down(ctx context.Context) error {
_, err := a.s.Down(ctx, &proto.DownRequest{})
return err
}
func (a *serverAgent) Status() (internal.StatusType, error) {
return internal.CtxGetState(a.s.rootCtx).Status()
}
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
case proto.OSLifecycleRequest_SLEEP:
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -39,7 +39,7 @@ func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*pro
// CleanState handles cleaning of states (performing cleanup operations) // CleanState handles cleaning of states (performing cleanup operations)
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) { func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
} }
@@ -82,7 +82,7 @@ func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (
// DeleteState handles deletion of states without cleanup // DeleteState handles deletion of states without cleanup
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) { func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting { if s.connectClient != nil && (s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting) {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.") return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
} }

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
@@ -268,7 +269,7 @@ func getDefaultDaemonAddr() string {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows return DefaultDaemonAddrWindows
} }
return DefaultDaemonAddr return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
} }
// DialOptions contains options for SSH connections // DialOptions contains options for SSH connections

View File

@@ -46,8 +46,10 @@ const (
cmdSFTP = "<sftp>" cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>" cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server // DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server.
DefaultJWTMaxTokenAge = 5 * 60 // Set to 10 minutes to accommodate identity providers like Azure Entra ID
// that backdate the iat claim by up to 5 minutes.
DefaultJWTMaxTokenAge = 10 * 60
) )
var ( var (

View File

@@ -323,7 +323,7 @@ type serviceClient struct {
exitNodeMu sync.Mutex exitNodeMu sync.Mutex
mExitNodeItems []menuHandler mExitNodeItems []menuHandler
exitNodeStates []exitNodeState exitNodeRetryCancel context.CancelFunc
mExitNodeDeselectAll *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem
logFile string logFile string
wLoginURL fyne.Window wLoginURL fyne.Window
@@ -924,7 +924,7 @@ func (s *serviceClient) updateStatus() error {
s.mDown.Enable() s.mDown.Enable()
s.mNetworks.Enable() s.mNetworks.Enable()
s.mExitNode.Enable() s.mExitNode.Enable()
go s.updateExitNodes() s.startExitNodeRefresh()
systrayIconState = true systrayIconState = true
case status.Status == string(internal.StatusConnecting): case status.Status == string(internal.StatusConnecting):
s.setConnectingStatus() s.setConnectingStatus()
@@ -985,6 +985,7 @@ func (s *serviceClient) setDisconnectedStatus() {
s.mUp.Enable() s.mUp.Enable()
s.mNetworks.Disable() s.mNetworks.Disable()
s.mExitNode.Disable() s.mExitNode.Disable()
s.cancelExitNodeRetry()
go s.updateExitNodes() go s.updateExitNodes()
} }

View File

@@ -100,8 +100,7 @@ func (h *eventHandler) handleConnectClick() {
func (h *eventHandler) handleDisconnectClick() { func (h *eventHandler) handleDisconnectClick() {
h.client.mDown.Disable() h.client.mDown.Disable()
h.client.cancelExitNodeRetry()
h.client.exitNodeStates = []exitNodeState{}
if h.client.connectCancel != nil { if h.client.connectCancel != nil {
log.Debugf("cancelling ongoing connect operation") log.Debugf("cancelling ongoing connect operation")

View File

@@ -6,7 +6,6 @@ import (
"context" "context"
"fmt" "fmt"
"runtime" "runtime"
"slices"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -34,11 +33,6 @@ const (
type filter string type filter string
type exitNodeState struct {
id string
selected bool
}
func (s *serviceClient) showNetworksUI() { func (s *serviceClient) showNetworksUI() {
s.wNetworks = s.app.NewWindow("Networks") s.wNetworks = s.app.NewWindow("Networks")
s.wNetworks.SetOnClosed(s.cancel) s.wNetworks.SetOnClosed(s.cancel)
@@ -335,16 +329,75 @@ func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs,
s.updateNetworks(grid, f) s.updateNetworks(grid, f)
} }
func (s *serviceClient) updateExitNodes() { // startExitNodeRefresh initiates exit node menu refresh after connecting.
// On Windows, TrayOpenedCh is not supported by the systray library, so we use
// a background poller to keep exit nodes in sync while connected.
// On macOS/Linux, TrayOpenedCh handles refreshes on each tray open.
func (s *serviceClient) startExitNodeRefresh() {
s.cancelExitNodeRetry()
if runtime.GOOS == "windows" {
ctx, cancel := context.WithCancel(s.ctx)
s.exitNodeMu.Lock()
s.exitNodeRetryCancel = cancel
s.exitNodeMu.Unlock()
go s.pollExitNodes(ctx)
} else {
go s.updateExitNodes()
}
}
func (s *serviceClient) cancelExitNodeRetry() {
s.exitNodeMu.Lock()
if s.exitNodeRetryCancel != nil {
s.exitNodeRetryCancel()
s.exitNodeRetryCancel = nil
}
s.exitNodeMu.Unlock()
}
// pollExitNodes periodically refreshes exit nodes while connected.
// Uses a short initial interval to catch routes from the management sync,
// then switches to a longer interval for ongoing updates.
func (s *serviceClient) pollExitNodes(ctx context.Context) {
// Initial fast polling to catch routes as they appear after connect.
for i := 0; i < 5; i++ {
if s.updateExitNodes() {
break
}
select {
case <-ctx.Done():
return
case <-time.After(2 * time.Second):
}
}
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.updateExitNodes()
}
}
}
// updateExitNodes fetches exit nodes from the daemon and recreates the menu.
// Returns true if exit nodes were found.
func (s *serviceClient) updateExitNodes() bool {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) log.Errorf("get client: %v", err)
return return false
} }
exitNodes, err := s.getExitNodes(conn) exitNodes, err := s.getExitNodes(conn)
if err != nil { if err != nil {
log.Errorf("get exit nodes: %v", err) log.Errorf("get exit nodes: %v", err)
return return false
} }
s.exitNodeMu.Lock() s.exitNodeMu.Lock()
@@ -354,28 +407,14 @@ func (s *serviceClient) updateExitNodes() {
if len(s.mExitNodeItems) > 0 { if len(s.mExitNodeItems) > 0 {
s.mExitNode.Enable() s.mExitNode.Enable()
} else { return true
s.mExitNode.Disable()
} }
s.mExitNode.Disable()
return false
} }
func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
var exitNodeIDs []exitNodeState
for _, node := range exitNodes {
exitNodeIDs = append(exitNodeIDs, exitNodeState{
id: node.ID,
selected: node.Selected,
})
}
sort.Slice(exitNodeIDs, func(i, j int) bool {
return exitNodeIDs[i].id < exitNodeIDs[j].id
})
if slices.Equal(s.exitNodeStates, exitNodeIDs) {
log.Debug("Exit node menu already up to date")
return
}
for _, node := range s.mExitNodeItems { for _, node := range s.mExitNodeItems {
node.cancel() node.cancel()
node.Hide() node.Hide()
@@ -413,8 +452,6 @@ func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {
go s.handleChecked(ctx, node.ID, menuItem) go s.handleChecked(ctx, node.ID, menuItem)
} }
s.exitNodeStates = exitNodeIDs
if showDeselectAll { if showDeselectAll {
s.mExitNode.AddSeparator() s.mExitNode.AddSeparator()
deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All") deselectAllItem := s.mExitNode.AddSubMenuItem("Deselect All", "Deselect All")

5
combined/Dockerfile Normal file
View File

@@ -0,0 +1,5 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server

Some files were not shown because too many files have changed in this diff Show More