Compare commits

...

63 Commits

Author SHA1 Message Date
Diego Romar
f894da0b11 [client] change handling of EOF on jobStream 2026-02-03 09:02:52 -03:00
Misha Bragin
b20d484972 [docs] Add selfhosting video (#5235) 2026-02-01 16:06:36 +01:00
Vlad
8931293343 [management] run cancelPeerRoutinesWithoutLock in sync (#5234) 2026-02-01 15:44:27 +01:00
Vlad
7b830d8f72 disable sync lim (#5233) 2026-02-01 14:37:00 +01:00
Misha Bragin
3a0cf230a1 Disable local users for a smooth single-idp mode (#5226)
Add LocalAuthDisabled option to embedded IdP configuration

This adds the ability to disable local (email/password) authentication when using the embedded Dex identity provider. When disabled, users can only authenticate via external
identity providers (Google, OIDC, etc.).

This simplifies user login when there is only one external IdP configured. The login page will redirect directly to the IdP login page.

Key changes:

Added LocalAuthDisabled field to EmbeddedIdPConfig
Added methods to check and toggle local auth: IsLocalAuthEnabled, HasNonLocalConnectors, DisableLocalAuth, EnableLocalAuth
Validation prevents disabling local auth if no external connectors are configured
Existing local users are preserved when disabled and can login again when re-enabled
Operations are idempotent (disabling already disabled is a no-op)
2026-02-01 14:26:22 +01:00
Viktor Liu
0c990ab662 [client] Add block inbound option to the embed client (#5215) 2026-01-30 10:42:39 +01:00
Viktor Liu
101c813e98 [client] Add macOS default resolvers as fallback (#5201) 2026-01-30 10:42:14 +01:00
Zoltan Papp
5333e55a81 Fix WG watcher missing initial handshake (#5213)
Start the WireGuard watcher before configuring the WG endpoint to ensure it captures the initial handshake timestamp.

Previously, the watcher was started after endpoint configuration, causing it to miss the handshake that occurred during setup.
2026-01-29 16:58:10 +01:00
Viktor Liu
81c11df103 [management] Streamline domain validation (#5211) 2026-01-29 13:51:44 +01:00
Viktor Liu
f74bc48d16 [Client] Stop NetBird on firewall init failure (#5208) 2026-01-29 11:05:06 +01:00
Vlad
0169e4540f [management] fix skip of ephemeral peers on deletion (#5206) 2026-01-29 10:58:45 +01:00
Vlad
cead3f38ee [management] fix ephemeral peers being not removed (#5203) 2026-01-28 18:24:12 +01:00
Zoltan Papp
b55262d4a2 [client] Refactor/optimise raw socket headers (#5174)
Pre-create and reuse packet headers to eliminate per-packet allocations.
2026-01-28 15:06:59 +01:00
Zoltan Papp
2248ff392f Remove redundant square bracket trimming in USP endpoint parsing (#5197) 2026-01-27 20:10:59 +01:00
Viktor Liu
06966da012 [client] Support non-PTY no-command interactive SSH sessions (#5093) 2026-01-27 11:05:04 +01:00
Viktor Liu
d4f7df271a [cllient] Don't track ebpf traffic in conntrack (#5166) 2026-01-27 11:04:23 +01:00
Maycon Santos
5299549eb6 [client] Update fyne and add exit menu retry (#5187)
* Update fyne and add exit menu retry

- Fix an extra arrow on macos by updating fyne/systray

* use systray.TrayOpenedCh instead of loop and retry
2026-01-27 09:52:55 +01:00
Misha Bragin
7d791620a6 Add user invite link feature for embedded IdP (#5157) 2026-01-27 09:42:20 +01:00
Zoltan Papp
44ab454a13 [management] Fix peer deletion error handling (#5188)
When a deleted peer tries to reconnect, GetUserIDByPeerKey was returning
Internal error instead of NotFound, causing clients to retry indefinitely
instead of recognizing the unrecoverable PermissionDenied error.

This fix:
1. Updates GetUserIDByPeerKey to properly return NotFound when peer doesn't exist
2. Updates Sync handler to convert NotFound to PermissionDenied with message
   'peer is not registered', matching the behavior of GetAccountIDForPeerKey

Fixes the regression introduced in v0.61.1 where deleted peers would see:
- Before: 'rpc error: code = Internal desc = failed handling request' (retry loop)
- After: 'rpc error: code = PermissionDenied desc = peer is not registered' (exits)
2026-01-26 23:15:34 +01:00
Misha Bragin
11f50d6c38 Include default groups claim in CLI audience (#5186) 2026-01-26 22:26:29 +01:00
Zoltan Papp
05af39a69b [client] Add IPv6 support to UDP WireGuard proxy (#5169)
* Add IPv6 support to UDP WireGuard proxy

Add IPv6 packet header support in UDP raw socket proxy
to handle both IPv4 and IPv6 source addresses.
Refactor error handling in proxy bind implementations
to validate endpoints before acquiring locks.
2026-01-26 14:03:32 +01:00
Viktor Liu
074df56c3d [client] Fix flaky JWT SSH test (#5181) 2026-01-26 09:30:00 +01:00
Maycon Santos
2381e216e4 Fix validator message with warn (#5168) 2026-01-24 17:49:25 +01:00
Zoltan Papp
ded04b7627 [client] Consolidate authentication logic (#5010)
* Consolidate authentication logic

- Moving auth functions from client/internal to client/internal/auth package
- Creating unified auth.Auth client with NewAuth() constructor
- Replacing direct auth function calls with auth client methods
- Refactoring device flow and PKCE flow implementations
- Updating iOS/Android/server code to use new auth client API

* Refactor PKCE auth and login methods

- Remove unnecessary internal package reference in PKCE flow test
- Adjust context assignment placement in iOS and Android login methods
2026-01-23 22:28:32 +01:00
Maycon Santos
67211010f7 [client, gui] fix exit nodes menu on reconnect, remove tooltips (#5167)
* [client, gui] fix exit nodes menu on reconnect

clean s.exitNodeStates when disconnecting

* disable tooltip for exit nodes and settings
2026-01-23 18:39:45 +01:00
Maycon Santos
c61568ceb4 [client] Change default rosenpass log level (#5137)
* Change default rosenpass log level

- Add support to environment configuration
- Change default log level to info

* use .String() for print log level
2026-01-23 18:06:54 +01:00
Vlad
737d6061bf [management] ephemeral peers track on login (#5165) 2026-01-23 18:05:22 +01:00
Zoltan Papp
ee3a67d2d8 [client] Fix/health result in bundle (#5164)
* Add support for optional status refresh callback during debug bundle generation

* Always update wg status

* Remove duplicated wg status call
2026-01-23 17:06:07 +01:00
Viktor Liu
1a32e4c223 [client] Fix IPv4-only in bind proxy (#5154) 2026-01-23 15:15:34 +01:00
Viktor Liu
269d5d1cba [client] Try next DNS upstream on SERVFAIL/REFUSED responses (#5163) 2026-01-23 11:59:52 +01:00
Bethuel Mmbaga
a1de2b8a98 [management] Move activity store encryption to shared crypt package (#5111) 2026-01-22 15:01:13 +03:00
Viktor Liu
d0221a3e72 [client] Add cpu profile to debug bundle (#4700) 2026-01-22 12:24:12 +01:00
Bethuel Mmbaga
8da23daae3 [management] Fix activity event initiator for user group changes (#5152) 2026-01-22 14:18:46 +03:00
Viktor Liu
f86022eace [client] Hide forwarding rules in status when count is zero (#5149)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 10:01:08 +01:00
Viktor Liu
ee54827f94 [client] Add IPv6 support to usersace bind (#5147) 2026-01-22 10:20:43 +08:00
Zoltan Papp
e908dea702 [client] Extend WG watcher for ICE connection too (#5133)
Extend WG watcher for ICE connection too
2026-01-21 10:42:13 +01:00
Maycon Santos
030650a905 [client] Fix RFC 4592 wildcard matching for existing domain names (#5145)
Per RFC 4592 section 2.2.1, wildcards should only match when the queried
name does not exist in the zone. Previously, if host.example.com had an
A record and *.example.com had an AAAA record, querying AAAA for
host.example.com would incorrectly return the wildcard AAAA instead of
NODATA.

Now the resolver checks if the domain exists (with any record type)
before falling back to wildcard matching, returning proper NODATA
responses for existing names without the requested record type.
2026-01-21 08:48:32 +01:00
Misha Bragin
e01998815e [infra] add embedded STUN to getting started (#5141) 2026-01-20 19:01:34 +01:00
Zoltan Papp
07e4a5a23c Fixes profile switching and repeated down/up command failures. (#5142)
When Down() and Up() are called in quick succession, the connectWithRetryRuns goroutine could set ErrResetConnection after Down() had cleared the state, causing the subsequent Up() to fail.

Fix by waiting for the goroutine to exit (via clientGiveUpChan) before Down() returns. Uses a 5-second timeout to prevent RPC timeouts while ensuring the goroutine completes in most cases.
2026-01-20 18:22:37 +01:00
Diego Romar
b3a2992a10 [client/android] - Fix Rosenpass connectivity for Android peers (#5044)
* [client] Add WGConfigurer interface

To allow Rosenpass to work both with kernel
WireGuard via wgctrl (default behavior) and
userspace WireGuard via IPC on Android/iOS
using WGUSPConfigurer

* [client] Remove Rosenpass debug logs

* [client] Return simpler peer configuration in outputKey method

ConfigureDevice, the method previously used in
outputKey via wgClient to update the device's
properties, is now defined in the WGConfigurer
interface and implemented both in kernel_unix and
usp configurers.

PresharedKey datatype was also changed from
boolean to [32]byte to compare it
to the original NetBird PSK, so that Rosenpass
may replace it with its own when necessary.

* [client] Remove unused field

* [client] Replace usage of WGConfigurer

Replaced with preshared key setter interface,
which only defines a method to set / update the preshared key.

Logic has been migrated from rosenpass/netbird_handler to client/iface.

* [client] Use same default peer keepalive value when setting preshared keys

* [client] Store PresharedKeySetter iface in rosenpass manager

To avoid no-op if SetInterface is called before generateConfig

* [client] Add mutex usage in rosenpass netbird handler

* [client] change implementation setting Rosenpass preshared key

Instead of providing a method to configure a device (device/interface.go),
it forwards the new parameters to the configurer (either
kernel_unix.go / usp.go).

This removes dependency on reading FullStats, and makes use of a common
method (buildPresharedKeyConfig in configurer/common.go) to build a
minimal WG config that only sets/updates the PSK.

netbird_handler.go now keeps s list of initializedPeers to choose whether
to set the value of "UpdateOnly" when calling iface.SetPresharedKey.

* [client] Address possible race condition

Between outputKey calls and peer removal; it
checks again if the peer still exists in the
peers map before inserting it in the
initializedPeers map.

* [client] Add psk Rosenpass-initialized check

On client/internal/peer/conn.go, the presharedKey
function would always return the current key
set in wgConfig.presharedKey.

This would eventually overwrite a key set
by Rosenpass if the feature is active.

The purpose here is to set a handler that will
check if a given peer has its psk initialized
by Rosenpass to skip updating the psk
via updatePeer (since it calls presharedKey
method in conn.go).

* Add missing updateOnly flag setup for usp peers

* Change common.go buildPresharedKeyConfig signature

PeerKey datatype changed from string to
wgTypes.Key. Callers are responsible for parsing
a peer key with string datatype.
2026-01-20 13:26:51 -03:00
Maycon Santos
202fa47f2b [client] Add support to wildcard custom records (#5125)
* **New Features**
  * Wildcard DNS fallback for eligible query types (excluding NS/SOA): attempts wildcard records when no exact match, rewrites wildcard names back to the original query, and rotates responses; preserves CNAME resolution.

* **Tests**
  * Vastly expanded coverage for wildcard behaviors, precedence, multi-record round‑robin, multi-type chains, multi-hop and cross-zone scenarios, and edge cases (NXDOMAIN/NODATA, fallthrough).

* **Chores**
  * CI lint config updated to ignore an additional codespell entry.
2026-01-20 17:21:25 +01:00
Misha Bragin
4888021ba6 Add missing activity events to the API response (#5140) 2026-01-20 15:12:22 +01:00
Misha Bragin
a0b0b664b6 Local user password change (embedded IdP) (#5132) 2026-01-20 14:16:42 +01:00
Diego Romar
50da5074e7 [client] change notifyDisconnected call (#5138)
On handleJobStream, when handling error codes 
from receiveJobRequest in the switch-case, 
notifying disconnected in cases where it isn't a 
disconnection breaks connection status reporting 
on mobile peers.

This commit changes it so it isn't called on
Canceled or Unimplemented status codes.
2026-01-20 07:14:33 -03:00
Zoltan Papp
58daa674ef [Management/Client] Trigger debug bundle runs from API/Dashboard (#4592) (#4832)
This PR adds the ability to trigger debug bundle generation remotely from the Management API/Dashboard.
2026-01-19 11:22:16 +01:00
Maycon Santos
245481f33b [client] fix: client/Dockerfile to reduce vulnerabilities (#5119)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-01-16 18:05:41 +01:00
shuuri-labs
b352ab84c0 Feat/quickstart reverse proxy assistant (#5100)
* add external reverse proxy config steps to quickstart script

* remove generated files

* - Remove 'press enter' prompt from post-traefik config since traefik requires no manual config
- Improve npm flow (ask users for docker network, user container names in config)

* fixes for npm flow

* nginx flow fixes

* caddy flow fixes

* Consolidate NPM_NETWORK, NGINX_NETWORK, CADDY_NETWORK into single
EXTERNAL_PROXY_NETWORK variable. Add read_proxy_docker_network()
function that prompts for Docker network for options 2-4 (Nginx,
NPM, Caddy). Generated configs now use container names when a
Docker network is specified.

* fix https for traefik

* fix sonar code smells

* fix sonar smell (add return to render_dashboard_env)

* added tls instructions to nginx flow

* removed unused bind_addr variable from quickstart.sh

* Refactor getting-started.sh for improved maintainability

Break down large functions into focused, single-responsibility components:
- Split init_environment() into 6 initialization functions
- Split print_post_setup_instructions() into 6 proxy-specific functions
- Add section headers for better code organization
- Fix 3 code smell issues (unused bind_addr variables)
- Add TLS certificate documentation for Nginx
- Link reverse proxy names to docs sections

Reduces largest function from 205 to ~90 lines while maintaining
single-file distribution. No functional changes.

* - Remove duplicate network display logic in Traefik instructions
- Use upstream_host instead of bind_addr for NPM forward hostname
- Use upstream_host instead of bind_addr in manual proxy route examples
- Prevents displaying invalid 0.0.0.0 as connection target in setup instructions

* add wait_management_direct to caddy flow to ensure script waits until containers are running/passing healthchecks before reporting 'done!'
2026-01-16 17:42:28 +01:00
ressys1978
3ce5d6a4f8 [management] Add idp timeout env variable (#4647)
Introduced the NETBIRD_IDP_TIMEOUT environment variable to the management service. This allows configuring a timeout for supported IDPs. If the variable is unset or contains an invalid value, a default timeout of 10 seconds is used as a fallback.

This is needed for larger IDP environments where 10s is just not enough time.
2026-01-16 16:23:37 +01:00
Misha Bragin
4c2eb2af73 [management] Skip email_verified if not present (#5118) 2026-01-16 16:01:39 +01:00
Misha Bragin
daf1449174 [client] Remove duplicate audiences check (#5117) 2026-01-16 14:25:02 +02:00
Misha Bragin
1ff7abe909 [management, client] Fix SSH server audience validator (#5105)
* **New Features**
  * SSH server JWT validation now accepts multiple audiences with backward-compatible handling of the previous single-audience setting and a guard ensuring at least one audience is configured.
* **Tests**
  * Test suites updated and new tests added to cover multiple-audience scenarios and compatibility with existing behavior.
* **Other**
  * Startup logging enhanced to report configured audiences for JWT auth.
2026-01-16 12:28:17 +01:00
Bethuel Mmbaga
067c77e49e [management] Add custom dns zones (#4849) 2026-01-16 12:12:05 +03:00
Maycon Santos
291e640b28 [client] Change priority between local and dns route handlers (#5106)
* Change priority between local and dns route handlers

* update priority tests
2026-01-15 17:30:10 +01:00
Pascal Fischer
efb954b7d6 [management] adapt ratelimiting (#5080) 2026-01-15 16:39:14 +01:00
Vlad
cac9326d3d [management] fetch all users data from external cache in one request (#5104)
---------

Co-authored-by: pascal <pascal@netbird.io>
2026-01-14 17:09:17 +01:00
Viktor Liu
520d9c66cf [client] Fix netstack upstream dns and add wasm debug methods (#4648) 2026-01-14 13:56:16 +01:00
Misha Bragin
ff10498a8b Feature/embedded STUN (#5062) 2026-01-14 13:13:30 +01:00
Zoltan Papp
00b747ad5d Handle fallback for invalid loginuid in ui-post-install.sh. (#5099) 2026-01-14 09:53:14 +01:00
Zoltan Papp
d9118eb239 [client] Fix WASM peer connection to lazy peers (#5097)
WASM peers now properly initiate relay connections instead of waiting for offers that lazy peers won't send.
2026-01-13 13:33:15 +01:00
Nima Sadeghifard
94de656fae [misc] Add hiring announcement with link to careers.netbird.io (#5095) 2026-01-12 19:06:28 +01:00
Misha Bragin
37abab8b69 [management] Check config compatibility (#5087)
* Enforce HttpConfig overwrite when embeddedIdp is enabled

* Disable offline_access scope in dashboard by default

* Add group propagation foundation to embedded idp

* Require groups scope in dex config for okt and pocket

* remove offline_access from device default scopes
2026-01-12 17:09:03 +01:00
Viktor Liu
b12c084a50 [client] Fall through dns chain for custom dns zones (#5081) 2026-01-12 13:56:39 +01:00
Viktor Liu
394ad19507 [client] Chase CNAMEs in local resolver to ensure musl compatibility (#5046) 2026-01-12 12:35:38 +01:00
253 changed files with 26147 additions and 4472 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
golangci:
strategy:

View File

@@ -38,6 +38,11 @@
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
@@ -55,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.2
FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -3,15 +3,7 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
return err
}
return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
err, _ = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
err, _ = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
go urlOpener.OnLoginSuccess()
return nil
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -221,21 +219,37 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
cpuProfilingStarted := false
if _, err := client.StartCPUProfile(cmd.Context(), &proto.StartCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to start CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = true
defer func() {
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
}
}
}()
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
} else {
cpuProfilingStarted = false
}
}
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
}
@@ -302,25 +316,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
return nil
}
func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@@ -379,7 +374,8 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
InternalConfig: config,
StatusRecorder: recorder,
SyncResponse: syncResponse,
LogFile: logFilePath,
LogPath: logFilePath,
CPUProfile: nil,
},
debug.BundleConfig{
IncludeSystemInfo: true,

View File

@@ -7,7 +7,6 @@ import (
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
}
jwtToken := ""
@@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
err, _ = authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return nil
@@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -99,17 +99,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
statusOutputString = outputInformationHolder.FullDetailSummary()
case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
statusOutputString, err = outputInformationHolder.JSON()
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
statusOutputString, err = outputInformationHolder.YAML()
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
}
if err != nil {

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
@@ -97,6 +98,8 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -115,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -124,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {

View File

@@ -16,10 +16,12 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
@@ -38,6 +40,7 @@ type Client struct {
setupKey string
jwtToken string
connect *internal.ConnectClient
recorder *peer.Status
}
// Options configures a new Client.
@@ -66,6 +69,8 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
}
// validateCredentials checks that exactly one credential type is provided
@@ -134,6 +139,7 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -161,26 +167,40 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
if c.connect != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
if err := client.Run(run, ""); err != nil {
clientErr <- err
}
}()
@@ -197,6 +217,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
c.connect = client
c.cancel = cancel
return nil
}
@@ -211,17 +232,23 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted
}
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1)
connect := c.connect
go func() {
done <- c.connect.Stop()
done <- connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
c.connect = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
c.connect = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
@@ -315,6 +342,62 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.cleanupNoTrackChain(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
@@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
tableRaw = "raw"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
// Egress rules: match outgoing loopback UDP packets
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
return fmt.Errorf("add output sport notrack rule: %w", err)
}
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
return fmt.Errorf("add output dport notrack rule: %w", err)
}
// Ingress rules: match incoming loopback UDP packets
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
}
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)
}
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("create chain: %w", err)
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add prerouting jump rule: %w", err)
}
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
return nil
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("clear and delete chain: %w", err)
}
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -168,6 +168,10 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -48,8 +49,10 @@ type Manager struct {
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
router *router
aclManager *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
}
// Create nftables firewall manager
@@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
@@ -288,7 +295,15 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.Flush()
if err := m.aclManager.Flush(); err != nil {
return err
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
return nil
}
// AddDNATRule adds a DNAT rule
@@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
return fmt.Errorf("notrack chains not initialized")
}
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
loopback := []byte{127, 0, 0, 1}
// Egress rules: match outgoing loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
// Ingress rules: match incoming loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
&expr.Counter{},
&expr.Notrack{},
},
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush notrack rules: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityRaw,
})
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawPrerouting,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRaw,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush chain creation: %w", err)
}
return nil
}
func (m *Manager) refreshNoTrackChains() error {
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
tableName := getTableName()
for _, c := range chains {
if c.Table.Name != tableName {
continue
}
switch c.Name {
case chainNameRawOutput:
m.notrackOutputChain = c
case chainNameRawPrerouting:
m.notrackPreroutingChain = c
}
}
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -570,6 +570,14 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {

View File

@@ -0,0 +1,169 @@
package bind
import (
"errors"
"net"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
// to the appropriate connection based on the destination address.
// ReadFrom is not used in the hot path - ICEBind receives packets via
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
type DualStackPacketConn struct {
ipv4Conn net.PacketConn
ipv6Conn net.PacketConn
readFromWarn sync.Once
}
// NewDualStackPacketConn creates a new dual-stack packet connection.
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
return &DualStackPacketConn{
ipv4Conn: ipv4Conn,
ipv6Conn: ipv6Conn,
}
}
// ReadFrom reads from the available connection (preferring IPv4).
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
d.readFromWarn.Do(func() {
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
})
if d.ipv4Conn != nil {
return d.ipv4Conn.ReadFrom(b)
}
if d.ipv6Conn != nil {
return d.ipv6Conn.ReadFrom(b)
}
return 0, nil, net.ErrClosed
}
// WriteTo writes to the appropriate connection based on the address type.
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Addr: addr,
Err: errInvalidAddr,
}
}
if udpAddr.IP.To4() == nil {
if d.ipv6Conn != nil {
return d.ipv6Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp6",
Addr: addr,
Err: errNoIPv6Conn,
}
}
if d.ipv4Conn != nil {
return d.ipv4Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp4",
Addr: addr,
Err: errNoIPv4Conn,
}
}
// Close closes both connections.
func (d *DualStackPacketConn) Close() error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// LocalAddr returns the local address of the IPv4 connection if available,
// otherwise the IPv6 connection.
func (d *DualStackPacketConn) LocalAddr() net.Addr {
if d.ipv4Conn != nil {
return d.ipv4Conn.LocalAddr()
}
if d.ipv6Conn != nil {
return d.ipv6Conn.LocalAddr()
}
return nil
}
// SetDeadline sets the deadline for both connections.
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetReadDeadline sets the read deadline for both connections.
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetWriteDeadline sets the write deadline for both connections.
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -0,0 +1,119 @@
package bind
import (
"net"
"testing"
)
var (
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
payload = make([]byte, 1200)
)
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = conn.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn.Close()
ds := NewDualStackPacketConn(conn, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn.Close()
ds := NewDualStackPacketConn(nil, conn)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv4Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, ipv6Addr)
}
}
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
b.Fatal(err)
}
defer conn4.Close()
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
b.Skipf("IPv6 not available: %v", err)
}
defer conn6.Close()
ds := NewDualStackPacketConn(conn4, conn6)
addrs := []net.Addr{ipv4Addr, ipv6Addr}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = ds.WriteTo(payload, addrs[i&1])
}
}

View File

@@ -0,0 +1,191 @@
package bind
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
ipv4Conn := &mockPacketConn{network: "udp4"}
ipv6Conn := &mockPacketConn{network: "udp6"}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
tests := []struct {
name string
addr *net.UDPAddr
wantSocket string
}{
{
name: "IPv4 address",
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 address",
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
wantSocket: "udp6",
},
{
name: "IPv4-mapped IPv6 goes to IPv4",
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv4 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
wantSocket: "udp4",
},
{
name: "IPv6 loopback",
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
wantSocket: "udp6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ipv4Conn.writeCount = 0
ipv6Conn.writeCount = 0
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
require.NoError(t, err)
assert.Equal(t, 4, n)
if tt.wantSocket == "udp4" {
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
} else {
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
}
})
}
}
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
// IPv4 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.NoError(t, err)
// IPv6 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv6 connection")
}
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
// IPv6 works
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
require.NoError(t, err)
// IPv4 fails
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
require.Error(t, err)
assert.Contains(t, err.Error(), "no IPv4 connection")
}
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
// only reads from one socket (IPv4 preferred). This is fine because the actual
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
ipv4Conn := &mockPacketConn{
network: "udp4",
readData: []byte("from ipv4"),
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
}
ipv6Conn := &mockPacketConn{
network: "udp6",
readData: []byte("from ipv6"),
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
}
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
buf := make([]byte, 100)
n, addr, err := dualStack.ReadFrom(buf)
require.NoError(t, err)
// reads from IPv4 (preferred) - this is expected behavior
assert.Equal(t, "from ipv4", string(buf[:n]))
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
}
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
tests := []struct {
name string
ipv4 net.PacketConn
ipv6 net.PacketConn
wantAddr net.Addr
}{
{
name: "both available returns IPv4",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv4Addr,
},
{
name: "IPv4 only",
ipv4: &mockPacketConn{localAddr: ipv4Addr},
ipv6: nil,
wantAddr: ipv4Addr,
},
{
name: "IPv6 only",
ipv4: nil,
ipv6: &mockPacketConn{localAddr: ipv6Addr},
wantAddr: ipv6Addr,
},
{
name: "neither returns nil",
ipv4: nil,
ipv6: nil,
wantAddr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
})
}
}
// mock
type mockPacketConn struct {
network string
writeCount int
readData []byte
readAddr net.Addr
localAddr net.Addr
}
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
if m.readData != nil {
return copy(b, m.readData), m.readAddr, nil
}
return 0, nil, nil
}
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
m.writeCount++
return len(b), nil
}
func (m *mockPacketConn) Close() error { return nil }
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@@ -14,7 +14,6 @@ import (
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
@@ -28,22 +27,7 @@ type receiverCreator struct {
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
}
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
buf := bufs[0]
size, ep, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
return 0, err
}
sizes[0] = size
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
eps[0] = stdEp
return 1, nil
}
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
@@ -73,6 +57,8 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
@@ -118,6 +104,12 @@ func (s *ICEBind) Close() error {
close(s.closedChan)
s.muUDPMux.Lock()
s.ipv4Conn = nil
s.ipv6Conn = nil
s.udpMux = nil
s.muUDPMux.Unlock()
return s.StdNetBind.Close()
}
@@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
// Detect IPv4 vs IPv6 from connection's local address
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
s.ipv4Conn = conn
} else {
s.ipv6Conn = conn
}
s.createOrUpdateMux()
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
@@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
//nolint:staticcheck
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
@@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
continue
}
sizes[i] = msg.N
@@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
}
}
// createOrUpdateMux creates or updates the UDP mux with the available connections.
// Must be called with muUDPMux held.
func (s *ICEBind) createOrUpdateMux() {
var muxConn net.PacketConn
switch {
case s.ipv4Conn != nil && s.ipv6Conn != nil:
muxConn = NewDualStackPacketConn(
nbnet.WrapPacketConn(s.ipv4Conn),
nbnet.WrapPacketConn(s.ipv6Conn),
)
case s.ipv4Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
case s.ipv6Conn != nil:
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
default:
return
}
// Don't close the old mux - it doesn't own the underlying connections.
// The sockets are managed by WireGuard's StdNetBind, not by us.
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},
)
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
@@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
s.muUDPMux.Lock()
mux := s.udpMux
s.muUDPMux.Unlock()
if mux != nil {
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
log.Warnf("failed to handle STUN packet: %v", muxErr)
}
}
buffers[i] = []byte{}

View File

@@ -0,0 +1,324 @@
package bind
import (
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
"github.com/pion/transport/v3/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, ipv6Conn := createDualStackConns(t)
defer ipv4Conn.Close()
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
pool := createMsgPool()
// Simulate wireguard-go calling CreateReceiverFn for IPv4
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
require.NotNil(t, ipv4RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
iceBind.muUDPMux.Unlock()
// Simulate wireguard-go calling CreateReceiverFn for IPv6
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
require.NotNil(t, ipv6RecvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
defer ipv4Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.NotNil(t, iceBind.ipv4Conn)
assert.Nil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
iceBind := setupICEBind(t)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Conn.Close()
rc := receiverCreator{iceBind}
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
require.NotNil(t, recvFn)
iceBind.muUDPMux.Lock()
assert.Nil(t, iceBind.ipv4Conn)
assert.NotNil(t, iceBind.ipv6Conn)
assert.NotNil(t, iceBind.udpMux)
iceBind.muUDPMux.Unlock()
mux, err := iceBind.GetICEMux()
require.NoError(t, err)
require.NotNil(t, mux)
}
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
// with peers on different address families through the same DualStackPacketConn.
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
// two "remote peers" listening on different address families
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
// our local dual-stack connection
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
// send to both peers
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
require.NoError(t, err)
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
require.NoError(t, err)
// verify IPv4 peer got its packet from the IPv4 socket
buf := make([]byte, 100)
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err := ipv4Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv4", string(buf[:n]))
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
// verify IPv6 peer got its packet from the IPv6 socket
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
n, addr, err = ipv6Peer.ReadFrom(buf)
require.NoError(t, err)
assert.Equal(t, "to-ipv6", string(buf[:n]))
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
}
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Peer.Close()
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
if err != nil {
t.Skipf("IPv6 not available: %v", err)
}
defer ipv6Peer.Close()
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
defer ipv4Local.Close()
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
defer ipv6Local.Close()
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
const packetsPerFamily = 500
ipv4Received := make(chan string, packetsPerFamily)
ipv6Received := make(chan string, packetsPerFamily)
startGate := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv4Peer.ReadFrom(buf)
if err != nil {
return
}
ipv4Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
buf := make([]byte, 100)
for i := 0; i < packetsPerFamily; i++ {
n, _, err := ipv6Peer.ReadFrom(buf)
if err != nil {
return
}
ipv6Received <- string(buf[:n])
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
}
}()
wg.Add(1)
go func() {
defer wg.Done()
<-startGate
for i := 0; i < packetsPerFamily; i++ {
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
}
}()
close(startGate)
time.AfterFunc(5*time.Second, func() {
_ = ipv4Peer.SetReadDeadline(time.Now())
_ = ipv6Peer.SetReadDeadline(time.Now())
})
wg.Wait()
close(ipv4Received)
close(ipv6Received)
ipv4Count := 0
for pkt := range ipv4Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
ipv4Count++
}
ipv6Count := 0
for pkt := range ipv6Received {
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
ipv6Count++
}
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
tests := []struct {
name string
network string
addr string
wantIPv4 bool
}{
{"IPv4 any", "udp4", "0.0.0.0:0", true},
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
{"IPv6 any", "udp6", "[::]:0", false},
{"IPv6 loopback", "udp6", "[::1]:0", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
require.NoError(t, err)
conn, err := net.ListenUDP(tt.network, addr)
if err != nil {
t.Skipf("%s not available: %v", tt.network, err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
isIPv4 := localAddr.IP.To4() != nil
assert.Equal(t, tt.wantIPv4, isIPv4)
})
}
}
// helpers
func setupICEBind(t *testing.T) *ICEBind {
t.Helper()
transportNet, err := stdnet.NewNet()
require.NoError(t, err)
address := wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, nil, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
t.Helper()
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
if err != nil {
ipv4Conn.Close()
t.Skipf("IPv6 not available: %v", err)
}
return ipv4Conn, ipv6Conn
}
func createMsgPool() *sync.Pool {
return &sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, 1)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, 0, 40)
}
return &msgs
},
}
}
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
t.Helper()
udpAddr, err := net.ResolveUDPAddr(network, addr)
require.NoError(t, err)
conn, err := net.ListenUDP(network, udpAddr)
require.NoError(t, err)
return conn
}

View File

@@ -3,8 +3,22 @@ package configurer
import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// buildPresharedKeyConfig creates a wgtypes.Config for setting a preshared key on a peer.
// This is a shared helper used by both kernel and userspace configurers.
func buildPresharedKeyConfig(peerKey wgtypes.Key, psk wgtypes.Key, updateOnly bool) wgtypes.Config {
return wgtypes.Config{
Peers: []wgtypes.PeerConfig{{
PublicKey: peerKey,
PresharedKey: &psk,
UpdateOnly: updateOnly,
}},
}
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {

View File

@@ -15,8 +15,6 @@ import (
"github.com/netbirdio/netbird/monotime"
)
var zeroKey wgtypes.Key
type KernelConfigurer struct {
deviceName string
}
@@ -48,6 +46,18 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *KernelConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.configure(cfg)
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -279,7 +289,7 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
PresharedKey: [32]byte(p.PresharedKey),
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint

View File

@@ -22,17 +22,16 @@ import (
)
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -72,6 +71,18 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
// SetPresharedKey sets the preshared key for a peer.
// If updateOnly is true, only updates the existing peer; if false, creates or updates.
func (c *WGUSPConfigurer) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
parsedPeerKey, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
cfg := buildPresharedKeyConfig(parsedPeerKey, psk, updateOnly)
return c.device.IpcSet(toWgUserspaceString(cfg))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
@@ -422,23 +433,19 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
hexKey := hex.EncodeToString(p.PublicKey[:])
sb.WriteString(fmt.Sprintf("public_key=%s\n", hexKey))
if p.Remove {
sb.WriteString("remove=true\n")
}
if p.UpdateOnly {
sb.WriteString("update_only=true\n")
}
if p.PresharedKey != nil {
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
}
if p.Remove {
sb.WriteString("remove=true")
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
if p.Endpoint != nil {
sb.WriteString(fmt.Sprintf("endpoint=%s\n", p.Endpoint.String()))
}
@@ -446,6 +453,14 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
if p.PersistentKeepaliveInterval != nil {
sb.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", int(p.PersistentKeepaliveInterval.Seconds())))
}
if p.ReplaceAllowedIPs {
sb.WriteString("replace_allowed_ips=true\n")
}
for _, aip := range p.AllowedIPs {
sb.WriteString(fmt.Sprintf("allowed_ip=%s\n", aip.String()))
}
}
return sb.String()
}
@@ -543,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
host, portStr, err := net.SplitHostPort(val)
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
@@ -599,7 +614,9 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
if val != "" && val != "0000000000000000000000000000000000000000000000000000000000000000" {
currentPeer.PresharedKey = true
if pskKey, err := hexToWireguardKey(val); err == nil {
currentPeer.PresharedKey = [32]byte(pskKey)
}
}
}
}

View File

@@ -12,7 +12,7 @@ type Peer struct {
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
PresharedKey [32]byte
}
type Stats struct {

View File

@@ -17,6 +17,7 @@ type WGConfigurer interface {
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
Close()
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)

View File

@@ -50,6 +50,7 @@ func ValidateMTU(mtu uint16) error {
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
Free() error
}
@@ -80,6 +81,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetProxyPort returns the proxy port used by the WireGuard proxy.
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
func (w *WGIface) GetProxyPort() uint16 {
return w.wgProxyFactory.GetProxyPort()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
@@ -297,6 +304,19 @@ func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
// SetPresharedKey sets or updates the preshared key for a peer.
// If updateOnly is true, only updates existing peer; if false, creates or updates.
func (w *WGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
return w.configurer.SetPresharedKey(peerKey, psk, updateOnly)
}
func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime)

View File

@@ -114,21 +114,21 @@ func (p *ProxyBind) Pause() {
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to start package redirection: %v", err)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.wgCurrentUsed = ep
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
@@ -212,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
}
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, fmt.Errorf("invalid address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}

View File

@@ -8,8 +8,6 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
@@ -26,13 +24,10 @@ const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
proxyPort int
mtu uint16
ebpfManager ebpfMgr.Manager
@@ -40,7 +35,8 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex
lastUsedPort uint16
rawConn net.PacketConn
rawConnIPv4 net.PacketConn
rawConnIPv6 net.PacketConn
conn transport.UDPConn
ctx context.Context
@@ -62,23 +58,39 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
proxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.proxyPort = proxyPort
// Prepare IPv4 raw socket (required)
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
if err != nil {
return err
}
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
// Prepare IPv6 raw socket (optional)
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
if err != nil {
return err
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
}
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
if err != nil {
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
}
if p.rawConnIPv6 != nil {
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
}
}
return err
}
addr := net.UDPAddr{
Port: wgPorxyPort,
Port: proxyPort,
IP: net.ParseIP(loopbackAddr),
}
@@ -94,7 +106,7 @@ func (p *WGEBPFProxy) Listen() error {
p.conn = conn
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
log.Infof("local wg proxy listening on: %d", proxyPort)
return nil
}
@@ -135,12 +147,25 @@ func (p *WGEBPFProxy) Free() error {
result = multierror.Append(result, err)
}
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
if p.rawConnIPv4 != nil {
if err := p.rawConnIPv4.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if p.rawConnIPv6 != nil {
if err := p.rawConnIPv6.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// GetProxyPort returns the proxy listening port.
func (p *WGEBPFProxy) GetProxyPort() uint16 {
return uint16(p.proxyPort)
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -216,34 +241,3 @@ generatePort:
}
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,12 +10,89 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
ctx context.Context
cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
wgRelayedEndpointAddr *net.UDPAddr
headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
paused bool
pausedCond *sync.Cond
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr
return err
p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
if !p.isStarted {
p.isStarted = true
@@ -91,10 +188,32 @@ func (p *ProxyWrapper) Pause() {
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
if endpoint == nil || endpoint.IP == nil {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
@@ -136,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait()
}
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
p.pausedCond.L.Unlock()
if err != nil {
@@ -162,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
return n, nil
}
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -54,6 +54,14 @@ func (w *KernelFactory) GetProxy() Proxy {
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
func (w *KernelFactory) GetProxyPort() uint16 {
if w.ebpfProxy == nil {
return 0
}
return w.ebpfProxy.GetProxyPort()
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil

View File

@@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu)
}
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
func (w *USPFactory) GetProxyPort() uint16 {
return 0
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -8,43 +8,87 @@ import (
"os"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/client/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET, true)
}
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET6, false)
}
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
if isIPv4 {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
} else {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
}
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file: %v", closeErr)
}
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
// Close the original file to release the FD (net.FilePacketConn duplicates it)
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
}
return packetConn, nil
}

View File

@@ -0,0 +1,353 @@
//go:build linux && !android
package wgproxy
import (
"context"
"net"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
func compareUDPAddr(addr1, addr2 net.Addr) bool {
udpAddr1, ok1 := addr1.(*net.UDPAddr)
udpAddr2, ok2 := addr2.(*net.UDPAddr)
if !ok1 || !ok2 {
return addr1.String() == addr2.String()
}
// Compare IP and Port, ignoring zone
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
func TestRedirectAs_UDP_IPv6(t *testing.T) {
wgPort := 51853
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// testRedirectAs is a helper function that tests the RedirectAs functionality
// It verifies that:
// 1. Initial traffic from relay connection works
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
// 3. Multiple packets are correctly redirected with the new source address
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
t.Helper()
ctx := context.Background()
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
}
defer wgListener4.Close()
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
IP: net.ParseIP("::1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
}
defer wgListener6.Close()
// Determine which listener to use based on the NetBird address IP version
// (this is where initial traffic will come from before RedirectAs is called)
var wgListener *net.UDPConn
if p2pEndpoint.IP.To4() == nil {
wgListener = wgListener6
} else {
wgListener = wgListener4
}
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0, // Random port
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
// Add TURN connection to proxy
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
// Start the proxy
proxy.Work()
// Phase 1: Test initial relay traffic
msgFromRelay := []byte("hello from relay")
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write to relay server: %v", err)
}
// Set read deadline to avoid hanging
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
buf := make([]byte, 1024)
n, _, err := wgListener4.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read from WireGuard listener: %v", err)
}
if n != len(msgFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
}
if string(buf[:n]) != string(msgFromRelay) {
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
}
// Phase 2: Redirect to p2p endpoint
proxy.RedirectAs(p2pEndpoint)
// Give the proxy a moment to process the redirect
time.Sleep(100 * time.Millisecond)
// Phase 3: Test redirected traffic
redirectedMessages := [][]byte{
[]byte("redirected message 1"),
[]byte("redirected message 2"),
[]byte("redirected message 3"),
}
for i, msg := range redirectedMessages {
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
}
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
}
// Verify message content
if string(buf[:n]) != string(msg) {
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
}
// Verify source address matches p2p endpoint (this is the key test)
// Use compareUDPAddr to ignore IPv6 zone IDs
if !compareUDPAddr(srcAddr, p2pEndpoint) {
t.Errorf("message %d: expected source address %s, got %s",
i+1, p2pEndpoint.String(), srcAddr.String())
}
}
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
ctx := context.Background()
// Create WireGuard listener
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create WireGuard listener: %v", err)
}
defer wgListener.Close()
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
proxy.Work()
// Test switching between multiple endpoints - using addresses in local subnet
endpoints := []*net.UDPAddr{
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
}
for i, endpoint := range endpoints {
proxy.RedirectAs(endpoint)
time.Sleep(100 * time.Millisecond)
msg := []byte("test message")
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
}
buf := make([]byte, 1024)
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
}
if string(buf[:n]) != string(msg) {
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
}
if !compareUDPAddr(srcAddr, endpoint) {
t.Errorf("endpoint %d: expected source %s, got %s",
i, endpoint.String(), srcAddr.String())
}
}
}

View File

@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {

View File

@@ -19,37 +19,56 @@ var (
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
localHostNetIPAddrV4 = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
localHostNetIPAddrV6 = &net.IPAddr{
IP: net.ParseIP("::1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
localHostAddr *net.IPAddr
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
// Create only the raw socket for the address family we need
var rawSocket net.PacketConn
var err error
var localHostAddr *net.IPAddr
if srcAddr.IP.To4() != nil {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
localHostAddr = localHostNetIPAddrV4
} else {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
localHostAddr = localHostNetIPAddrV6
}
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
if closeErr := rawSocket.Close(); closeErr != nil {
log.Warnf("failed to close raw socket: %v", closeErr)
}
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
}
return f, nil
@@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
@@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
// Check if source IP is IPv4 or IPv6
if srcAddr.IP.To4() != nil {
// IPv4
ipv4 := &layers.IPv4{
DstIP: localHostNetIPAddrV4.IP,
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
} else {
// IPv6
ipv6 := &layers.IPv6{
DstIP: localHostNetIPAddrV6.IP,
SrcIP: srcAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
err := udpH.SetNetworkLayerForChecksum(networkLayer)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}

View File

@@ -0,0 +1,499 @@
package auth
import (
"context"
"net/url"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Auth manages authentication operations with the management server
// It maintains a long-lived connection and automatically handles reconnection with backoff
type Auth struct {
mutex sync.RWMutex
client *mgm.GrpcClient
config *profilemanager.Config
privateKey wgtypes.Key
mgmURL *url.URL
mgmTLSEnabled bool
}
// NewAuth creates a new Auth instance that manages authentication flows
// It establishes a connection to the management server that will be reused for all operations
// The connection is automatically recreated with backoff if it becomes disconnected
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
// Validate WireGuard private key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
return nil, err
}
// Determine TLS setting based on URL scheme
mgmTLSEnabled := mgmURL.Scheme == "https"
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
return nil, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
return &Auth{
client: mgmClient,
config: config,
privateKey: myPrivateKey,
mgmURL: mgmURL,
mgmTLSEnabled: mgmTLSEnabled,
}, nil
}
// Close closes the management client connection
func (a *Auth) Close() error {
a.mutex.Lock()
defer a.mutex.Unlock()
if a.client == nil {
return nil
}
return a.client.Close()
}
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
var supportsSSO bool
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
// Try PKCE flow first
_, err := a.getPKCEFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if PKCE is not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// PKCE not supported, try Device flow
_, err = a.getDeviceFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if Device flow is also not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// Neither PKCE nor Device flow is supported
supportsSSO = false
return nil
}
// Device flow check returned an error other than NotFound/Unimplemented
return err
}
// PKCE flow check returned an error other than NotFound/Unimplemented
return err
})
return supportsSSO, err
}
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
// This avoids creating a new connection to the management server
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
var flow OAuthFlow
var err error
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
if forceDeviceAuth {
flow, err = a.getDeviceFlow(client)
return err
}
// Try PKCE flow first
flow, err = a.getPKCEFlow(client)
if err != nil {
// If PKCE not supported, try Device flow
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
flow, err = a.getDeviceFlow(client)
return err
}
return err
}
return nil
})
return flow, err
}
// IsLoginRequired checks if login is required by attempting to authenticate with the server
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return false, err
}
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
}
needsLogin = false
return err
})
return needsLogin, err
}
// Login attempts to log in or register the client with the management server
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return err, false
}
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
} else if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
isAuthError = false
return nil
})
return err, isAuthError
}
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
RedirectURLs: protoConfig.GetRedirectURLs(),
UseIDToken: protoConfig.GetUseIDToken(),
ClientCertPair: a.config.ClientCertKeyPair,
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
}
if err := validatePKCEConfig(config); err != nil {
return nil, err
}
flow, err := NewPKCEAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
Scope: protoConfig.GetScope(),
UseIDToken: protoConfig.GetUseIDToken(),
}
// Keep compatibility with older management versions
if config.Scope == "" {
config.Scope = "openid"
}
if err := validateDeviceAuthConfig(config); err != nil {
return nil, err
}
flow, err := NewDeviceAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
// setSystemInfoFlags sets all configuration flags on the provided system info
func (a *Auth) setSystemInfoFlags(info *system.Info) {
info.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
}
// reconnect closes the current connection and creates a new one
// It checks if the brokenClient is still the current client before reconnecting
// to avoid multiple threads reconnecting unnecessarily
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
a.mutex.Lock()
defer a.mutex.Unlock()
// Double-check: if client has already been replaced by another thread, skip reconnection
if a.client != brokenClient {
log.Debugf("client already reconnected by another thread, skipping")
return nil
}
// Create new connection FIRST, before closing the old one
// This ensures a.client is never nil, preventing panics in other threads
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
if err != nil {
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
// Keep the old client if reconnection fails
return err
}
// Close old connection AFTER new one is successfully created
oldClient := a.client
a.client = mgmClient
if oldClient != nil {
if err := oldClient.Close(); err != nil {
log.Debugf("error closing old connection: %v", err)
}
}
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
return nil
}
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
func isConnectionError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
// These error codes indicate connection issues
return s.Code() == codes.Unavailable ||
s.Code() == codes.DeadlineExceeded ||
s.Code() == codes.Canceled ||
s.Code() == codes.Internal
}
// withRetry wraps an operation with exponential backoff retry logic
// It automatically reconnects on connection errors
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
backoffSettings := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
backoffSettings.Reset()
return backoff.RetryNotify(
func() error {
// Capture the client BEFORE the operation to ensure we track the correct client
a.mutex.RLock()
currentClient := a.client
a.mutex.RUnlock()
if currentClient == nil {
return status.Errorf(codes.Unavailable, "client is not initialized")
}
// Execute operation with the captured client
err := operation(currentClient)
if err == nil {
return nil
}
// If it's a connection error, attempt reconnection using the client that was actually used
if isConnectionError(err) {
log.Warnf("connection error detected, attempting reconnection: %v", err)
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
log.Errorf("reconnection failed: %v", reconnectErr)
return reconnectErr
}
// Return the original error to trigger retry with the new connection
return err
}
// For authentication errors, don't retry
if isAuthenticationError(err) {
return backoff.Permanent(err)
}
return err
},
backoff.WithContext(backoffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("operation failed, retrying in %v: %v", duration, err)
},
)
}
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
func isAuthenticationError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
}
// isPermissionDenied checks if the error is a PermissionDenied error.
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
func isPermissionDenied(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.PermissionDenied
}
func isLoginNeeded(err error) bool {
return isAuthenticationError(err)
}
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}

View File

@@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -26,12 +25,56 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validateDeviceAuthConfig validates device authorization provider configuration
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMsgFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
}
return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,8 +12,6 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,7 +10,6 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
return pkceFlowInfo, nil
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
return deviceFlowInfo, nil
}

View File

@@ -20,7 +20,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -35,17 +34,67 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validatePKCEConfig validates PKCE provider configuration
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
}
return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig internal.PKCEAuthProviderConfig
providerConfig PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
excludedRanges := getSystemExcludedPortRanges()
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
go p.startServer(server, tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:

View File

@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
config := PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -9,8 +9,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
availablePort := 65432
config := internal.PKCEAuthProviderConfig{
config := PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -59,7 +59,6 @@ func NewConnectClient(
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
@@ -71,8 +70,8 @@ func NewConnectClient(
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan)
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
return c.run(MobileDependency{}, runningChan, logPath)
}
// RunOnAndroid with main logic on mobile system
@@ -93,7 +92,7 @@ func (c *ConnectClient) RunOnAndroid(
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) RunOniOS(
@@ -111,10 +110,10 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
return c.run(mobileDependency, nil, "")
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -284,7 +283,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
relayURLs, token := parseRelayInfo(loginResp)
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -420,6 +419,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
return syncResponse, nil
}
// SetLogLevel sets the log level for the firewall manager if the engine is running.
func (c *ConnectClient) SetLogLevel(level log.Level) {
engine := c.Engine()
if engine == nil {
return
}
fwManager := engine.GetFirewallManager()
if fwManager != nil {
fwManager.SetLogLevel(level)
}
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
@@ -459,7 +471,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
@@ -494,7 +506,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
ProfileConfig: config,
}
if config.PreSharedKey != "" {

View File

@@ -28,8 +28,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -57,6 +59,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
@@ -223,10 +226,11 @@ type BundleGenerator struct {
internalConfig *profilemanager.Config
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logFile string
logPath string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
anonymize bool
clientStatus string
includeSystemInfo bool
logFileCount uint32
@@ -235,7 +239,6 @@ type BundleGenerator struct {
type BundleConfig struct {
Anonymize bool
ClientStatus string
IncludeSystemInfo bool
LogFileCount uint32
}
@@ -244,7 +247,9 @@ type GeneratorDependencies struct {
InternalConfig *profilemanager.Config
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogFile string
LogPath string
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -260,10 +265,11 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
internalConfig: deps.InternalConfig,
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logFile: deps.LogFile,
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
}
@@ -309,13 +315,6 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add status: %w", err)
}
if g.statusRecorder != nil {
status := g.statusRecorder.GetFullStatus()
seedFromStatus(g.anonymizer, &status)
} else {
log.Debugf("no status recorder available for seeding")
}
if err := g.addConfig(); err != nil {
log.Errorf("failed to add config to debug bundle: %v", err)
}
@@ -332,6 +331,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addCPUProfile(); err != nil {
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
@@ -352,7 +355,7 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
@@ -401,11 +404,30 @@ func (g *BundleGenerator) addReadme() error {
}
func (g *BundleGenerator) addStatus() error {
if status := g.clientStatus; status != "" {
statusReader := strings.NewReader(status)
if g.statusRecorder != nil {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
if g.refreshStatus != nil {
g.refreshStatus()
}
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
seedFromStatus(g.anonymizer, &fullStatus)
} else {
log.Debugf("no status recorder available for seeding")
}
return nil
}
@@ -535,6 +557,19 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addCPUProfile() error {
if len(g.cpuProfile) == 0 {
return nil
}
reader := bytes.NewReader(g.cpuProfile)
if err := g.addFileToZip(reader, "cpu.prof"); err != nil {
return fmt.Errorf("add CPU profile to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
@@ -710,14 +745,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
}
func (g *BundleGenerator) addLogfile() error {
if g.logFile == "" {
if g.logPath == "" {
log.Debugf("skipping empty log file in debug bundle")
return nil
}
logDir := filepath.Dir(g.logFile)
logDir := filepath.Dir(g.logPath)
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}

View File

@@ -0,0 +1,101 @@
package debug
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}

View File

@@ -1,4 +1,4 @@
package server
package debug
import (
"context"
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
fileContent := []byte("test file content")
err := os.WriteFile(file, fileContent, 0640)
require.NoError(t, err)
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
require.NoError(t, err)
id := getURLHash(testURL)
require.Contains(t, key, id+"/")

View File

@@ -60,7 +60,7 @@ func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
if peer.PresharedKey != [32]byte{} {
sb.WriteString(" preshared key: (hidden)\n")
}
}

View File

@@ -1,136 +0,0 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return DeviceAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return DeviceAuthorizationFlow{}, err
}
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return DeviceAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return DeviceAuthorizationFlow{}, err
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
if zone.NonAuthoritative {
continue
}
for _, record := range zone.Records {

View File

@@ -3,17 +3,21 @@ package dns
import (
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
)
const (
PriorityMgmtCache = 150
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityDNSRoute = 100
PriorityLocal = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100
@@ -43,7 +47,23 @@ type HandlerChain struct {
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
requestID string
shouldContinue bool
response *dns.Msg
meta map[string]string
}
// RequestID returns the request ID for tracing
func (w *ResponseWriterChain) RequestID() string {
return w.requestID
}
// SetMeta sets a metadata key-value pair for logging
func (w *ResponseWriterChain) SetMeta(key, value string) {
if w.meta == nil {
w.meta = make(map[string]string)
}
w.meta[key] = value
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
w.shouldContinue = true
return nil
}
w.response = m
return w.ResponseWriter.WriteMsg(m)
}
@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
c.logHandlers()
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
c.logHandlers()
break
}
}
}
// logHandlers logs the current handler chain state. Caller must hold the lock.
func (c *HandlerChain) logHandlers() {
if !log.IsLevelEnabled(log.TraceLevel) {
return
}
var b strings.Builder
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
for _, h := range c.handlers {
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
" priority=" + strconv.Itoa(h.Priority) + "\n")
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := strings.ToLower(r.Question[0].Name)
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order
for _, entry := range handlers {
matched := c.isHandlerMatch(qname, entry)
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return
if !c.isHandlerMatch(qname, entry) {
continue
}
handlerName := entry.OrigPattern
if s, ok := entry.Handler.(interface{ String() string }); ok {
handlerName = s.String()
}
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
requestID: requestID,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
if entry.Priority != PriorityMgmtCache {
logger.Tracef("handler requested continue for domain=%s", qname)
}
continue
}
c.logResponse(logger, chainWriter, qname, startTime)
return
}
// No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname)
logger.Tracef("no handler found for domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
}
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
if cw.response == nil {
return
}
var meta string
for k, v := range cw.meta {
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":

View File

@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false,
shouldMatch: false,
},
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
}
for _, tt := range tests {

View File

@@ -9,8 +9,10 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -38,6 +40,9 @@ const (
type systemConfigurator struct {
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
}
func newHostManager() (*systemConfigurator, error) {
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
}
var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false
inServerAddressesArray := false
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
ip = ip.Unmap()
serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
}
}
}
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port
dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil
}
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.origNameservers)
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {

View File

@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput()
return err
}
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -1,30 +1,52 @@
package local
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
const externalResolutionTimeout = 4 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
ctx context.Context
cancel context.CancelFunc
}
func NewResolver() *Resolver {
ctx, cancel := context.WithCancel(context.Background())
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
zones: make(map[domain.Domain]bool),
ctx: ctx,
cancel: cancel,
}
}
@@ -37,7 +59,18 @@ func (d *Resolver) String() string {
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
func (d *Resolver) Stop() {
if d.cancel != nil {
d.cancel()
}
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
@@ -48,60 +81,150 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
logger.Debug("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
result := d.lookupRecords(logger, question)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
d.continueToNext(logger, w, r)
return
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
logger.Warnf("failed to write the local resolver response: %v", err)
}
}
// determineRcode returns the appropriate DNS response code.
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
// even if CNAME records are included in the answer.
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
// Use the rcode from lookup - this properly handles CNAME chains where
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
if result.rcode != 0 {
return result.rcode
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
return dns.RcodeSuccess
}
return dns.RcodeNameError
}
// findZone finds the matching zone for a query name using reverse suffix lookup.
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
qname = strings.ToLower(dns.Fqdn(qname))
for {
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
return nonAuth, true
}
// Move to parent domain
idx := strings.Index(qname, ".")
if idx == -1 || idx == len(qname)-1 {
return false, false
}
qname = qname[idx+1:]
}
}
// shouldFallthrough checks if the query should fallthrough to the next handler.
// Returns true if the queried name belongs to a non-authoritative zone.
func (d *Resolver) shouldFallthrough(qname string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
nonAuth, found := d.findZone(qname)
return found && nonAuth
}
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Warnf("failed to write continue signal: %v", err)
}
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
if !exists && supportsWildcard(qType) {
testWild := transformDomainToWildcard(string(domainName))
_, exists = d.domains[domain.Domain(testWild)]
}
return exists
}
// isInManagedZone checks if the given name falls within any of our managed zones.
// This is used to avoid unnecessary external resolution for CNAME targets that
// are within zones we manage - if we don't have a record for it, it doesn't exist.
// Caller must NOT hold the lock.
func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, found := d.findZone(name)
return found
}
// lookupResult contains the result of a DNS lookup operation.
type lookupResult struct {
records []dns.RR
rcode int
hasExternalData bool
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
d.mu.RLock()
records, found := d.records[question]
usingWildcard := false
wildQuestion := transformToWildcard(question)
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
// If the domain exists with any record type, return NODATA instead of wildcard match.
if !found && supportsWildcard(question.Qtype) {
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
records, found = d.records[wildQuestion]
usingWildcard = found
}
}
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
cnameQuestion := dns.Question{
Name: question.Name,
Qtype: dns.TypeCNAME,
Qclass: question.Qclass,
}
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
}
return nil
return lookupResult{rcode: dns.RcodeNameError}
}
recordsCopy := slices.Clone(records)
@@ -110,29 +233,229 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
q := question
if usingWildcard {
q = wildQuestion
}
records = d.records[q]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
d.records[q] = records
}
d.mu.Unlock()
}
return recordsCopy
if usingWildcard {
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
}
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
func transformToWildcard(question dns.Question) dns.Question {
wildQuestion := question
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
return wildQuestion
}
func transformDomainToWildcard(domain string) string {
s := strings.Split(domain, ".")
s[0] = "*"
return strings.Join(s, ".")
}
func supportsWildcard(queryType uint16) bool {
return queryType != dns.TypeNS && queryType != dns.TypeSOA
}
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
records := make([]dns.RR, len(wildRecords))
for i, record := range wildRecords {
copiedRecord := dns.Copy(record)
copiedRecord.Header().Name = originalName
records[i] = copiedRecord
}
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
const maxDepth = 8
var chain []dns.RR
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
wildQuestion := transformToWildcard(cnameQuestion)
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
}
}
if len(cnameRecords) == 0 {
break
}
chain = append(chain, cnameRecords...)
cname, ok := cnameRecords[0].(*dns.CNAME)
if !ok {
break
}
targetName := strings.ToLower(cname.Target)
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
// keep following chain
if targetResult.rcode == -1 {
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
continue
}
return d.buildChainResult(chain, targetResult)
}
if len(chain) > 0 {
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
}
return lookupResult{rcode: dns.RcodeSuccess}
}
// buildChainResult combines CNAME chain records with the target resolution result.
// Per RFC 6604, the final rcode is propagated through the chain.
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
records := chain
if len(target.records) > 0 {
records = append(records, target.records...)
}
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
return lookupResult{
records: records,
rcode: dns.RcodeServerFailure,
hasExternalData: true,
}
}
return lookupResult{
records: records,
rcode: target.rcode,
hasExternalData: target.hasExternalData,
}
}
// resolveCNAMETarget attempts to resolve a CNAME target name.
// Returns rcode=-1 to signal "keep following the chain".
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// another CNAME, keep following
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
return lookupResult{rcode: -1}
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
return lookupResult{rcode: dns.RcodeSuccess}
}
// in our zone but doesn't exist (NXDOMAIN)
if d.isInManagedZone(targetName) {
return lookupResult{rcode: dns.RcodeNameError}
}
return d.resolveExternal(logger, targetName, targetType)
}
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
d.mu.RLock()
defer d.mu.RUnlock()
return d.records[q]
}
func (d *Resolver) hasRecord(q dns.Question) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, ok := d.records[q]
return ok
}
// resolveExternal resolves a domain name using the system resolver.
// This is used to resolve CNAME targets that point outside our local zone,
// which is required for musl libc compatibility (musl expects complete answers).
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return lookupResult{rcode: dns.RcodeNotImplemented}
}
resolver := d.resolver
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
defer cancel()
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
if result.Err != nil {
d.logDNSError(logger, name, qtype, result.Err)
return lookupResult{rcode: result.Rcode, hasExternalData: true}
}
return lookupResult{
records: resutil.IPsToRRs(name, result.IPs, 60),
rcode: dns.RcodeSuccess,
hasExternalData: true,
}
}
// logDNSError logs DNS resolution errors for debugging.
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
qtypeName := dns.TypeToString[qtype]
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
return
}
if dnsErr.IsNotFound {
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
return
}
if dnsErr.Server != "" {
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
} else {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
}
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
for _, zone := range customZones {
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
d.zones[zoneDomain] = zone.NonAuthoritative
for _, rec := range zone.Records {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,197 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
return ""
}
}
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
RequestID() string
SetMeta(key, value string)
}
// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
if cw, ok := w.(chainedWriter); ok {
if id := cw.RequestID(); id != "" {
return id
}
}
return GenerateRequestID()
}
// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
if cw, ok := w.(chainedWriter); ok {
cw.SetMeta(key, value)
}
}
// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
IPs []netip.Addr
Rcode int
Err error // Original error for caller's logging needs
}
// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
ips, err := r.LookupNetIP(ctx, network, host)
if err != nil {
return LookupResult{
Rcode: getRcodeForError(ctx, r, host, qtype, err),
Err: err,
}
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
return LookupResult{
IPs: ips,
Rcode: dns.RcodeSuccess,
}
}
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
return dns.RcodeServerFailure
}
if dnsErr.IsNotFound {
return getRcodeForNotFound(ctx, r, host, qtype)
}
return dns.RcodeServerFailure
}
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
return dns.RcodeNameError
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
return dns.RcodeSuccess
}
// Alternative query succeeded - domain exists but has no records of this type
return dns.RcodeSuccess
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {
return "[]"
}
parts := make([]string, 0, len(answers))
for _, rr := range answers {
switch r := rr.(type) {
case *dns.A:
parts = append(parts, r.A.String())
case *dns.AAAA:
parts = append(parts, r.AAAA.String())
case *dns.CNAME:
parts = append(parts, "CNAME:"+r.Target)
case *dns.PTR:
parts = append(parts, "PTR:"+r.Ptr)
default:
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
}
}
return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
}
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -498,8 +498,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
// register local records
s.localResolver.Update(localRecords)
s.localResolver.Update(localZones)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -616,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config)
}
// registerFallback registers original nameservers as low-priority fallback handlers
// registerFallback registers original nameservers as low-priority fallback handlers.
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok {
@@ -625,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return
}
@@ -632,9 +632,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
@@ -659,9 +657,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
var muxUpdates []handlerWrapper
var localRecords []nbdns.SimpleRecord
var zones []nbdns.CustomZone
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -675,17 +673,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal,
})
// zone records contain the fqdn, so we can just flatten them
var localRecords []nbdns.SimpleRecord
for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
customZone.Records = localRecords
zones = append(zones, customZone)
}
return muxUpdates, localRecords, nil
return muxUpdates, zones, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
@@ -741,9 +742,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
domainGroup.domain,
@@ -924,9 +923,7 @@ func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
return configurer.WGStats{}, nil
}
func (w *mocWGIface) GetNet() *netstack.Net {
return nil
}
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@@ -128,7 +133,7 @@ func TestUpdateDNSServer(t *testing.T) {
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalRecords []nbdns.SimpleRecord
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
@@ -180,8 +185,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
@@ -221,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -251,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -273,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -299,8 +304,8 @@ func TestUpdateDNSServer(t *testing.T) {
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -315,8 +320,8 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -385,7 +390,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -510,8 +515,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
priority: PriorityUpstream,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@@ -2048,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -8,15 +8,21 @@ import (
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -2,7 +2,6 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -19,8 +18,10 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -70,6 +71,11 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
@@ -113,10 +119,10 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
u.prepareRequest(r)
@@ -125,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if u.tryUpstreamServers(w, r, logger) {
return
ok, failures := u.tryUpstreamServers(w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
if !ok {
u.writeErrorResponse(w, r, logger)
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
@@ -138,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -151,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
}
}
return false
return false, failures
}
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
@@ -173,53 +185,79 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
}()
if err != nil {
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
return false
return u.handleUpstreamError(err, upstream, startTime)
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
return &upstreamFailure{upstream: upstream, reason: "no response"}
}
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
return nil
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return
return &upstreamFailure{upstream: upstream, reason: err.Error()}
}
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond))
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
reason += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warn(timeoutMsg)
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
resutil.SetMeta(w, "upstream", upstream.String())
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
rm.MsgHdr.Zero = false
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
failedCount := len(failures)
failureSummary := formatFailures(failures)
if succeeded {
logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
} else {
logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary)
}
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err)
}
}
func formatFailures(failures []upstreamFailure) string {
parts := make([]string, 0, len(failures))
for _, f := range failures {
parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason))
}
return strings.Join(parts, ", ")
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
@@ -414,14 +452,53 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
return nil, err
}
return hex.EncodeToString(bytes)
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
return reply, nil
}
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
conn, err := nsNet.DialContext(ctx, network, upstream)
if err != nil {
return nil, fmt.Errorf("with %s: %w", network, err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close DNS connection: %v", err)
}
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)
}
reply, err := dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("read %s message: %w", network, err)
}
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts

View File

@@ -23,9 +23,7 @@ type upstreamResolver struct {
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,

View File

@@ -5,22 +5,23 @@ package dns
import (
"context"
"net/netip"
"runtime"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolver struct {
*upstreamResolverBase
nsNet *netstack.Net
}
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -28,12 +29,23 @@ func newUpstreamResolver(
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
// Similar to iOS logic, we should determine if the DNS server is reachable directly
// or needs to go through the tunnel, and only use netstack when necessary.
// For now, only use netstack on JS platform where direct access is not possible.
if u.nsNet != nil && runtime.GOOS == "js" {
start := time.Now()
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
return reply, time.Since(start), err
}
client := &dns.Client{
Timeout: ClientTimeout,
}

View File

@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip netip.Addr,
net netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -37,9 +35,9 @@ func newUpstreamResolver(
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
interfaceName: interfaceName,
lIP: wgIface.Address().IP,
lNet: wgIface.Address().Network,
interfaceName: wgIface.Name(),
}
ios.upstreamClient = ios

View File

@@ -2,13 +2,20 @@ package dns
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
@@ -58,7 +65,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
// Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
@@ -112,6 +119,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}
}
type mockNetstackProvider struct{}
func (m *mockNetstackProvider) Name() string { return "mock" }
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
@@ -123,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
}
type mockUpstreamResolverPerServer struct {
responses map[string]mockUpstreamResponse
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
@@ -174,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
t.Errorf("should be enabled")
}
}
func TestUpstreamResolver_Failover(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
successAnswer := "192.0.2.100"
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
testCases := []struct {
name string
upstream1 mockUpstreamResponse
upstream2 mockUpstreamResponse
expectedRcode int
expectAnswer bool
expectTrySecond bool
}{
{
name: "success on first upstream",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: false,
},
{
name: "SERVFAIL from first should try second",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "REFUSED from first should try second",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "NXDOMAIN from first should NOT try second",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeNameError,
expectAnswer: false,
expectTrySecond: false,
},
{
name: "timeout from first should try second",
upstream1: mockUpstreamResponse{err: timeoutErr},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "no response from first should try second",
upstream1: mockUpstreamResponse{msg: nil},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
expectedRcode: dns.RcodeSuccess,
expectAnswer: true,
expectTrySecond: true,
},
{
name: "both upstreams return SERVFAIL",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "both upstreams timeout",
upstream1: mockUpstreamResponse{err: timeoutErr},
upstream2: mockUpstreamResponse{err: timeoutErr},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "first SERVFAIL then timeout",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
upstream2: mockUpstreamResponse{err: timeoutErr},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "first timeout then SERVFAIL",
upstream1: mockUpstreamResponse{err: timeoutErr},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
{
name: "first REFUSED then SERVFAIL",
upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
expectedRcode: dns.RcodeServerFailure,
expectAnswer: false,
expectTrySecond: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var queriedUpstreams []string
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): tc.upstream1,
upstream2.String(): tc.upstream2,
},
rtt: time.Millisecond,
}
trackingClient := &trackingMockClient{
inner: mockClient,
queriedUpstreams: &queriedUpstreams,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: trackingClient,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, inputMSG)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode")
if tc.expectAnswer {
require.NotEmpty(t, responseMSG.Answer, "expected answer records")
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
}
if tc.expectTrySecond {
assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams")
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
assert.Equal(t, upstream2.String(), queriedUpstreams[1])
} else {
assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream")
assert.Equal(t, upstream1.String(), queriedUpstreams[0])
}
})
}
}
type trackingMockClient struct {
inner *mockUpstreamResolverPerServer
queriedUpstreams *[]string
}
func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) {
*t.queriedUpstreams = append(*t.queriedUpstreams, upstream)
return t.inner.exchange(ctx, upstream, r)
}
func buildMockResponse(rcode int, answer string) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if rcode == dns.RcodeSuccess && answer != "" {
m.Answer = []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.ParseIP(answer),
},
}
}
return m
}
func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
upstream := netip.MustParseAddrPort("192.0.2.1:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamServers: []netip.AddrPort{upstream},
upstreamTimeout: UpstreamTimeout,
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, inputMSG)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
failures []upstreamFailure
expected string
}{
{
name: "empty slice",
failures: []upstreamFailure{},
expected: "",
},
{
name: "single failure",
failures: []upstreamFailure{
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
},
expected: "8.8.8.8:53=SERVFAIL",
},
{
name: "multiple failures",
failures: []upstreamFailure{
{upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"},
{upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"},
},
expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := formatFailures(tc.failures)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -5,6 +5,8 @@ package dns
import (
"net"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -17,4 +19,5 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -1,6 +1,8 @@
package dns
import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -12,5 +14,6 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
GetInterfaceGUIDString() (string, error)
}

View File

@@ -18,6 +18,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
)
@@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return nil
}
question := query.Question[0]
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
var network string
switch question.Qtype {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
@@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err)
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
@@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query)
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
}
}
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
result resutil.LookupResult,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
resp.Rcode = result.Rcode
// NotFound: set NXDOMAIN / appropriate code via helper.
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
// Cached negative result - re-verify NXDOMAIN vs NODATA
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
return
}
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
log.Warnf(errResolveFailed, domain, err)
logger.Warnf(errResolveFailed, domain, result.Err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
@@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
@@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
@@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
@@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
@@ -42,12 +43,14 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -132,6 +135,11 @@ type EngineConfig struct {
LazyConnectionEnabled bool
MTU uint16
// for debug bundle generation
ProfileConfig *profilemanager.Config
LogPath string
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -195,7 +203,8 @@ type Engine struct {
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
@@ -211,6 +220,9 @@ type Engine struct {
shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -224,7 +236,18 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
stateManager *statemanager.Manager,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -244,6 +267,7 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -312,6 +336,8 @@ func (e *Engine) Stop() error {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
e.close()
// stop flow manager after wg interface is gone
@@ -479,6 +505,15 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// Set up notrack rules immediately after proxy is listening to prevent
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)
@@ -500,6 +535,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveJobEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
@@ -537,9 +573,11 @@ func (e *Engine) createFirewall() error {
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
if err != nil {
return fmt.Errorf("create firewall manager: %w", err)
}
if e.firewall == nil {
return fmt.Errorf("create firewall manager: received nil manager")
}
if err := e.initFirewall(); err != nil {
@@ -585,6 +623,23 @@ func (e *Engine) initFirewall() error {
return nil
}
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
func (e *Engine) setupWGProxyNoTrack() {
if e.firewall == nil {
return
}
proxyPort := e.wgInterface.GetProxyPort()
if proxyPort == 0 {
return
}
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
}
}
func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
@@ -828,9 +883,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled
if e.persistSyncResponse {
if enabled {
e.syncRespMux.Lock()
e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
@@ -960,6 +1024,80 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
func (e *Engine) receiveJobEvents() {
e.jobExecutorWG.Add(1)
go func() {
defer e.jobExecutorWG.Done()
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
resp := mgmProto.JobResponse{
ID: msg.ID,
Status: mgmProto.JobStatus_failed,
}
switch params := msg.WorkloadParameters.(type) {
case *mgmProto.JobRequest_Bundle:
bundleResult, err := e.handleBundle(params.Bundle)
if err != nil {
log.Errorf("handling bundle: %v", err)
resp.Reason = []byte(err.Error())
return &resp
}
resp.Status = mgmProto.JobStatus_succeeded
resp.WorkloadResults = bundleResult
return &resp
default:
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
return &resp
}
})
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
log.Info("stopped receiving jobs from Management Service")
}()
log.Info("connecting to Management Service jobs stream")
}
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
log.Infof("handle remote debug bundle request: %s", params.String())
syncResponse, err := e.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
bundleDeps := debug.GeneratorDependencies{
InternalConfig: e.config.ProfileConfig,
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
RefreshStatus: func() {
e.RunHealthProbes(true)
},
}
bundleJobParams := debug.BundleConfig{
Anonymize: params.Anonymize,
IncludeSystemInfo: true,
LogFileCount: uint32(params.LogFileCount),
}
waitFor := time.Duration(params.BundleForTime) * time.Minute
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
if err != nil {
return nil, err
}
response := &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: uploadKey,
},
}
return response, nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
@@ -1251,11 +1389,16 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
ForwarderPort: forwarderPort,
}
for _, zone := range protoDNSConfig.GetCustomZones() {
protoZones := protoDNSConfig.GetCustomZones()
// Treat single zone as authoritative for backward compatibility with old servers
// that only send the peer FQDN zone without setting field 4.
singleZoneCompat := len(protoZones) == 1
for _, zone := range protoZones {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
SkipPTRProcess: zone.GetSkipPTRProcess(),
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
@@ -1400,6 +1543,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
if e.rpManager != nil {
peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
peerConn.SetRosenpassInitializedPresharedKeyValidator(e.rpManager.IsPresharedKeyInitialized)
}
return peerConn, nil
@@ -1523,6 +1667,7 @@ func (e *Engine) parseNATExternalIPMappings() []string {
func (e *Engine) close() {
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
@@ -1709,7 +1854,7 @@ func (e *Engine) getRosenpassAddr() string {
return ""
}
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// RunHealthProbes executes health checks for Signal, Management, Relay, and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock()
@@ -1723,42 +1868,31 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
stuns := slices.Clone(e.STUNs)
turns := slices.Clone(e.TURNs)
if e.wgInterface != nil {
stats, err := e.wgInterface.GetStats()
if err != nil {
log.Warnf("failed to get wireguard stats: %v", err)
e.syncMsgMux.Unlock()
return false
}
for _, key := range e.peerStore.PeersPubKey() {
// wgStats could be zero value, in which case we just reset the stats
wgStats, ok := stats[key]
if !ok {
continue
}
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
if err := e.statusRecorder.RefreshWireGuardStats(); err != nil {
log.Debugf("failed to refresh WireGuard stats: %v", err)
}
e.syncMsgMux.Unlock()
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
// Skip STUN/TURN probing for JS/WASM as it's not available
relayHealthy := true
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
if runtime.GOOS != "js" {
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
}
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)
@@ -1839,8 +1973,8 @@ func (e *Engine) stopDNSServer() {
// SetSyncResponsePersistence enables or disables sync response persistence
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.Lock()
defer e.syncRespMux.Unlock()
if enabled == e.persistSyncResponse {
return
@@ -1855,20 +1989,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
if !e.persistSyncResponse {
if !enabled {
return nil, errors.New("sync response persistence is disabled")
}
if e.latestSyncResponse == nil {
if latest == nil {
//nolint:nilnil
return nil, nil
}
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("failed to clone sync response")
}

View File

@@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
@@ -106,6 +107,7 @@ type MockWGIface struct {
GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetProxyPortFunc func() uint16
GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]monotime.Time
}
@@ -202,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
return m.GetProxyFunc()
}
func (m *MockWGIface) GetProxyPort() uint16 {
if m.GetProxyPortFunc != nil {
return m.GetProxyPortFunc()
}
return 0
}
func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc()
}
@@ -213,6 +222,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}
func TestMain(m *testing.M) {
_ = util.InitLog("debug", util.LogConsole)
code := m.Run()
@@ -1599,6 +1612,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
@@ -1622,7 +1636,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
@@ -1631,7 +1645,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -28,6 +28,7 @@ type wgIfaceBase interface {
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error
@@ -42,4 +43,5 @@ type wgIfaceBase interface {
GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}

View File

@@ -1,201 +0,0 @@
package internal
import (
"context"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
mgmURL := config.ManagementURL
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
if err != nil {
return false, err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", mgmURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return false, err
}
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login or register the client
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
var mgmTlsEnabled bool
if mgmURL.Scheme == "https" {
mgmTlsEnabled = true
}
log.Debugf("connecting to the Management service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
return nil, err
}
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
return true
}
return false
}
func isRegistrationNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.PermissionDenied {
return true
}
return false
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()

View File

@@ -88,8 +88,9 @@ type Conn struct {
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
rosenpassInitializedPresharedKeyValidator func(peerKey string) bool
statusRelay *worker.AtomicWorkerStatus
statusICE *worker.AtomicWorkerStatus
@@ -98,7 +99,10 @@ type Conn struct {
workerICE *WorkerICE
workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
wgWatcher *WGWatcher
wgWatcherWg sync.WaitGroup
wgWatcherCancel context.CancelFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte
@@ -126,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key)
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
@@ -137,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
}
return conn, nil
@@ -162,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -231,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Infof("close peer connection")
conn.ctxCancel()
conn.workerRelay.DisableWgWatcher()
if conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
conn.workerICE.Close()
@@ -289,6 +297,13 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler
}
// SetRosenpassInitializedPresharedKeyValidator sets a function to check if Rosenpass has taken over
// PSK management for a peer. When this returns true, presharedKey() returns nil
// to prevent UpdatePeer from overwriting the Rosenpass-managed PSK.
func (conn *Conn) SetRosenpassInitializedPresharedKeyValidator(handler func(peerKey string) bool) {
conn.rosenpassInitializedPresharedKeyValidator = handler
}
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
@@ -366,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp
}
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
}
@@ -378,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
@@ -423,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
@@ -444,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() {
}
conn.statusICE.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(),
}
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil {
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
}
}
@@ -492,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
@@ -500,12 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
@@ -519,7 +525,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.handleRelayDisconnectedLocked()
}
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
func (conn *Conn) handleRelayDisconnectedLocked() {
if conn.ctx.Err() != nil {
return
}
@@ -545,6 +555,8 @@ func (conn *Conn) onRelayDisconnected() {
}
conn.statusRelay.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(),
@@ -563,6 +575,28 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) onWGDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
return
}
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
// Close the active connection based on current priority
switch conn.currentConnPriority {
case conntype.Relay:
conn.workerRelay.CloseConn()
conn.handleRelayDisconnectedLocked()
case conntype.ICEP2P, conntype.ICETurn:
conn.workerICE.Close()
default:
conn.Log.Debugf("No active connection to close on WG timeout")
}
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
@@ -669,10 +703,17 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false
@@ -682,6 +723,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true
}
func (conn *Conn) enableWgWatcherIfNeeded() {
if !conn.wgWatcher.IsEnabled() {
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
}()
}
}
func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{
@@ -752,10 +812,24 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return conn.config.WgConfig.PreSharedKey
}
// If Rosenpass has already set a PSK for this peer, return nil to prevent
// UpdatePeer from overwriting the Rosenpass-managed key.
if conn.rosenpassInitializedPresharedKeyValidator != nil && conn.rosenpassInitializedPresharedKeyValidator(conn.config.Key) {
return nil
}
// Use NetBird PSK as the seed for Rosenpass. This same PSK is passed to
// Rosenpass as PeerConfig.PresharedKey, ensuring the derived post-quantum
// key is cryptographically bound to the original secret.
if conn.config.WgConfig.PreSharedKey != nil {
return conn.config.WgConfig.PreSharedKey
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey
return nil
}
return determKey

View File

@@ -284,3 +284,27 @@ func TestConn_presharedKey(t *testing.T) {
})
}
}
func TestConn_presharedKey_RosenpassManaged(t *testing.T) {
conn := Conn{
config: ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
RosenpassConfig: RosenpassConfig{PubKey: []byte("dummykey")},
},
}
// When Rosenpass has already initialized the PSK for this peer,
// presharedKey must return nil to avoid UpdatePeer overwriting it.
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return true }
if k := conn.presharedKey([]byte("remote")); k != nil {
t.Fatalf("expected nil presharedKey when Rosenpass manages PSK, got %v", k)
}
// When Rosenpass hasn't taken over yet, presharedKey should provide
// a non-nil initial key (deterministic or from NetBird PSK).
conn.rosenpassInitializedPresharedKeyValidator = func(peerKey string) bool { return false }
if k := conn.presharedKey([]byte("remote")); k == nil {
t.Fatalf("expected non-nil presharedKey before Rosenpass manages PSK")
}
}

View File

@@ -14,6 +14,7 @@ import (
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -158,6 +159,7 @@ type FullStatus struct {
NSGroupStates []NSGroupState
NumOfForwardingRules int
LazyConnectionEnabled bool
Events []*proto.SystemEvent
}
type StatusChangeSubscription struct {
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
}
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
fullStatus.Events = d.GetEventHistory()
return fullStatus
}
@@ -1142,6 +1145,38 @@ func (d *Status) PeersStatus() (*configurer.Stats, error) {
return d.wgIface.FullStats()
}
// RefreshWireGuardStats fetches fresh WireGuard statistics from the interface
// and updates the cached peer states. This ensures accurate handshake times and
// transfer statistics in status reports without running full health probes.
func (d *Status) RefreshWireGuardStats() error {
d.mux.Lock()
defer d.mux.Unlock()
if d.wgIface == nil {
return nil // silently skip if interface not set
}
stats, err := d.wgIface.FullStats()
if err != nil {
return fmt.Errorf("get wireguard stats: %w", err)
}
// Update each peer's WireGuard statistics
for _, peerStats := range stats.Peers {
peerState, ok := d.peers[peerStats.PublicKey]
if !ok {
continue
}
peerState.LastWireguardHandshake = peerStats.LastHandshake
peerState.BytesRx = peerStats.RxBytes
peerState.BytesTx = peerStats.TxBytes
d.peers[peerStats.PublicKey] = peerState
}
return nil
}
type EventQueue struct {
maxSize int
events []*proto.SystemEvent
@@ -1181,3 +1216,97 @@ type EventSubscription struct {
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
return s.events
}
// ToProto converts FullStatus to proto.FullStatus.
func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
if err := fs.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fs.SignalState.URL
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
if err := fs.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
for _, peerState := range fs.Peers {
networks := maps.Keys(peerState.GetRoutes())
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: networks,
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fs.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fs.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
pbFullStatus.Events = fs.Events
return &pbFullStatus
}

View File

@@ -30,10 +30,8 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabledTime time.Time
enabled bool
muEnabled sync.RWMutex
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
return
}
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
w.log.Debugf("enable WireGuard watcher")
enabledTime := time.Now()
w.enabled = true
w.muEnabled.Unlock()
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
}
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
func (w *WGWatcher) DisableWgWatcher() {
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctxCancel == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancel()
w.ctxCancel = nil
// IsEnabled returns true if the WireGuard watcher is currently enabled
func (w *WGWatcher) IsEnabled() bool {
w.muEnabled.RLock()
defer w.muEnabled.RUnlock()
return w.enabled
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
lastHandshake := initialHandshake
@@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
return
}
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
elapsed := calcElapsed(enabledTime, *handshake)
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
@@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
// the current know handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}
// in case if the machine is suspended, the handshake time will be in the past
if handshake.Add(checkPeriod).Before(time.Now()) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false
}
// error handling for handshake time in the future
if handshake.After(time.Now()) {
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
return nil, false
}
@@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
}
return wgState.LastHandshake, nil
}
// calcElapsed calculates elapsed time since watcher was enabled.
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
func calcElapsed(enabledTime, handshake time.Time) float64 {
elapsed := handshake.Sub(enabledTime).Seconds()
if elapsed < 0 {
elapsed = 0
}
return elapsed
}

View File

@@ -2,6 +2,7 @@ package peer
import (
"context"
"sync"
"testing"
"time"
@@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}
func TestWGWatcher_ReEnable(t *testing.T) {
@@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
watcher.EnableWgWatcher(ctx, func() {})
}()
cancel()
wg.Wait()
// Re-enable with a new context
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher()
go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{}
})
@@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
case <-time.After(10 * time.Second):
t.Errorf("timeout")
}
watcher.DisableWgWatcher()
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
@@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addrString := pair.Remote.Address()
parsed, err := netip.ParseAddr(addrString)
if (err == nil) && (parsed.Is6()) {
addrString = fmt.Sprintf("[%s]", addrString)
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
}
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return
@@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
}
}
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
sessionID := w.SessionID()
stats := agent.GetCandidatePairsStats()
localCandidates, _ := agent.GetLocalCandidates()
remoteCandidates, _ := agent.GetRemoteCandidates()
localMap := make(map[string]ice.Candidate)
for _, c := range localCandidates {
localMap[c.ID()] = c
}
remoteMap := make(map[string]ice.Candidate)
for _, c := range remoteCandidates {
remoteMap[c.ID()] = c
}
for _, stat := range stats {
if stat.State == ice.CandidatePairStateSucceeded {
local, lok := localMap[stat.LocalCandidateID]
remote, rok := remoteMap[stat.RemoteCandidateID]
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
stat.CurrentRoundTripTime*1000)
}
}
}
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
w.logSuccessfulPaths(agent)
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to

View File

@@ -30,11 +30,9 @@ type WorkerRelay struct {
relayLock sync.Mutex
relaySupportedOnRemotePeer atomic.Bool
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,
@@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
config: config,
conn: conn,
relayManager: relayManager,
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
}
return r
}
@@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
})
}
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
}
func (w *WorkerRelay) DisableWgWatcher() {
w.wgWatcher.DisableWgWatcher()
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress()
}
@@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() {
}
}
func (w *WorkerRelay) onWGDisconnected() {
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.conn.onRelayDisconnected()
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() {
return false
@@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
}
func (w *WorkerRelay) onRelayClientDisconnected() {
w.wgWatcher.DisableWgWatcher()
go w.conn.onRelayDisconnected()
}

View File

@@ -1,138 +0,0 @@
package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -17,6 +17,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
defaultLog = slog.LevelInfo
defaultLogLevelVar = "NB_ROSENPASS_LOG_LEVEL"
)
func hashRosenpassKey(key []byte) string {
hasher := sha256.New()
hasher.Write(key)
@@ -34,6 +39,7 @@ type Manager struct {
server *rp.Server
lock sync.Mutex
port int
wgIface PresharedKeySetter
}
// NewManager creates a new Rosenpass manager
@@ -44,7 +50,7 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
}
rpKeyHash := hashRosenpassKey(public)
log.Debugf("generated new rosenpass key pair with public key %s", rpKeyHash)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
}
@@ -100,7 +106,7 @@ func (m *Manager) removePeer(wireGuardPubKey string) error {
func (m *Manager) generateConfig() (rp.Config, error) {
opts := &slog.HandlerOptions{
Level: slog.LevelDebug,
Level: getLogLevel(),
}
logger := slog.New(slog.NewTextHandler(os.Stdout, opts))
cfg := rp.Config{Logger: logger}
@@ -109,7 +115,13 @@ func (m *Manager) generateConfig() (rp.Config, error) {
cfg.SecretKey = m.ssk
cfg.Peers = []rp.PeerConfig{}
m.rpWgHandler, _ = NewNetbirdHandler(m.preSharedKey, m.ifaceName)
m.lock.Lock()
m.rpWgHandler = NewNetbirdHandler()
if m.wgIface != nil {
m.rpWgHandler.SetInterface(m.wgIface)
}
m.lock.Unlock()
cfg.Handlers = []rp.Handler{m.rpWgHandler}
@@ -126,6 +138,26 @@ func (m *Manager) generateConfig() (rp.Config, error) {
return cfg, nil
}
func getLogLevel() slog.Level {
level, ok := os.LookupEnv(defaultLogLevelVar)
if !ok {
return defaultLog
}
switch strings.ToLower(level) {
case "debug":
return slog.LevelDebug
case "info":
return slog.LevelInfo
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
log.Warnf("unknown log level: %s. Using default %s", level, defaultLog.String())
return defaultLog
}
}
func (m *Manager) OnDisconnected(peerKey string) {
m.lock.Lock()
defer m.lock.Unlock()
@@ -172,6 +204,20 @@ func (m *Manager) Close() error {
return nil
}
// SetInterface sets the WireGuard interface for the rosenpass handler.
// This can be called before or after Run() - the interface will be stored
// and passed to the handler when it's created or updated immediately if
// already running.
func (m *Manager) SetInterface(iface PresharedKeySetter) {
m.lock.Lock()
defer m.lock.Unlock()
m.wgIface = iface
if m.rpWgHandler != nil {
m.rpWgHandler.SetInterface(iface)
}
}
// OnConnected is a handler function that is triggered when a connection to a remote peer establishes
func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) {
m.lock.Lock()
@@ -192,6 +238,20 @@ func (m *Manager) OnConnected(remoteWireGuardKey string, remoteRosenpassPubKey [
}
}
// IsPresharedKeyInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for the given WireGuard peer.
func (m *Manager) IsPresharedKeyInitialized(wireGuardPubKey string) bool {
m.lock.Lock()
defer m.lock.Unlock()
peerID, ok := m.rpPeerIDs[wireGuardPubKey]
if !ok || peerID == nil {
return false
}
return m.rpWgHandler.IsPeerInitialized(*peerID)
}
func findRandomAvailableUDPPort() (int, error) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {

View File

@@ -1,46 +1,50 @@
package rosenpass
import (
"fmt"
"log/slog"
"sync"
rp "cunicu.li/go-rosenpass"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// PresharedKeySetter is the interface for setting preshared keys on WireGuard peers.
// This minimal interface allows rosenpass to update PSKs without depending on the full WGIface.
type PresharedKeySetter interface {
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
}
type wireGuardPeer struct {
Interface string
PublicKey rp.Key
}
type NetbirdHandler struct {
ifaceName string
client *wgctrl.Client
peers map[rp.PeerID]wireGuardPeer
presharedKey [32]byte
mu sync.Mutex
iface PresharedKeySetter
peers map[rp.PeerID]wireGuardPeer
initializedPeers map[rp.PeerID]bool
}
func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) {
hdlr = &NetbirdHandler{
ifaceName: wgIfaceName,
peers: map[rp.PeerID]wireGuardPeer{},
func NewNetbirdHandler() *NetbirdHandler {
return &NetbirdHandler{
peers: map[rp.PeerID]wireGuardPeer{},
initializedPeers: map[rp.PeerID]bool{},
}
}
if preSharedKey != nil {
hdlr.presharedKey = *preSharedKey
}
if hdlr.client, err = wgctrl.New(); err != nil {
return nil, fmt.Errorf("failed to creat WireGuard client: %w", err)
}
return hdlr, nil
// SetInterface sets the WireGuard interface for the handler.
// This must be called after the WireGuard interface is created.
func (h *NetbirdHandler) SetInterface(iface PresharedKeySetter) {
h.mu.Lock()
defer h.mu.Unlock()
h.iface = iface
}
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
h.mu.Lock()
defer h.mu.Unlock()
h.peers[pid] = wireGuardPeer{
Interface: intf,
PublicKey: pk,
@@ -48,79 +52,61 @@ func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
}
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
h.mu.Lock()
defer h.mu.Unlock()
delete(h.peers, pid)
delete(h.initializedPeers, pid)
}
// IsPeerInitialized returns true if Rosenpass has completed a handshake
// and set a PSK for this peer.
func (h *NetbirdHandler) IsPeerInitialized(pid rp.PeerID) bool {
h.mu.Lock()
defer h.mu.Unlock()
return h.initializedPeers[pid]
}
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
log.Debug("Handshake complete")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
key, _ := rp.GeneratePresharedKey()
log.Debug("Handshake expired")
h.outputKey(rp.KeyOutputReasonStale, pid, key)
}
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
h.mu.Lock()
iface := h.iface
wg, ok := h.peers[pid]
isInitialized := h.initializedPeers[pid]
h.mu.Unlock()
if iface == nil {
log.Warn("rosenpass: interface not set, cannot update preshared key")
return
}
if !ok {
return
}
device, err := h.client.Device(h.ifaceName)
if err != nil {
log.Errorf("Failed to get WireGuard device: %v", err)
peerKey := wgtypes.Key(wg.PublicKey).String()
pskKey := wgtypes.Key(psk)
// Use updateOnly=true for later rotations (peer already has Rosenpass PSK)
// Use updateOnly=false for first rotation (peer has original/empty PSK)
if err := iface.SetPresharedKey(peerKey, pskKey, isInitialized); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
return
}
config := []wgtypes.PeerConfig{
{
UpdateOnly: true,
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
},
}
for _, peer := range device.Peers {
if peer.PublicKey == wgtypes.Key(wg.PublicKey) {
if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey {
log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey)
config = []wgtypes.PeerConfig{
{
PublicKey: wgtypes.Key(wg.PublicKey),
PresharedKey: (*wgtypes.Key)(&psk),
Endpoint: peer.Endpoint,
AllowedIPs: peer.AllowedIPs,
},
}
err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
Remove: true,
PublicKey: wgtypes.Key(wg.PublicKey),
},
},
})
if err != nil {
slog.Debug("Failed to remove peer")
return
}
}
// Mark peer as isInitialized after the successful first rotation
if !isInitialized {
h.mu.Lock()
if _, exists := h.peers[pid]; exists {
h.initializedPeers[pid] = true
}
}
if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
Peers: config,
}); err != nil {
log.Errorf("Failed to apply rosenpass key: %v", err)
h.mu.Unlock()
}
}
func publicKeyEmpty(key wgtypes.Key) bool {
for _, b := range key {
if b != 0 {
return false
}
}
return true
}

View File

@@ -17,12 +17,13 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -37,11 +38,6 @@ type internalDNATer interface {
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
@@ -51,7 +47,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
wgInterface wgInterface
wgInterface iface.WGIface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
@@ -219,14 +215,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
if len(r.Question) == 0 {
return
}
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
@@ -249,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
@@ -263,32 +253,15 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
if reply == nil {
return
}
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
resutil.SetMeta(w, "peer", peerKey)
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
if err := d.writeMsg(w, reply, logger); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
}
@@ -324,11 +297,15 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
if r == nil {
return fmt.Errorf("received nil DNS message")
}
// Clear Zero bit from peer responses to prevent external sources from
// manipulating our internal fallthrough signaling mechanism
r.MsgHdr.Zero = false
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
@@ -350,14 +327,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
@@ -370,11 +347,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
logger.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes)
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
}
}
@@ -386,22 +363,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -418,9 +395,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
@@ -432,7 +409,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.addDNATMappings(dnatMappings)
d.addDNATMappings(dnatMappings, logger)
if !d.route.KeepRoute {
// Remove old prefixes
@@ -448,7 +425,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.removeDNATMappings(toRemove)
d.removeDNATMappings(toRemove, logger)
}
// Update domain prefixes using resolved domain as key - store real IPs
@@ -463,14 +440,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
if len(realPrefixes) == 0 {
return
}
@@ -484,9 +461,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -502,7 +479,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
if len(mappings) == 0 {
return
}
@@ -514,9 +491,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -528,12 +505,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes)
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
if _, ok := d.internalDnatFw(); !ok {
return
}
@@ -549,7 +526,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
@@ -560,7 +537,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
return
}
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
// Returns the DNS reply on success, or nil on error (error responses are written internally).
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
startTime := time.Now()
nsNet := d.wgInterface.GetNet()
var reply *dns.Msg
var err error
if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil
}
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
}
if err == nil {
return reply
}
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""

View File

@@ -4,6 +4,8 @@ import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -9,6 +9,8 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
return false, errors.New("not supported on mobile platforms")
}
if netstack.IsEnabled() {
log.Debugf("Interface monitor: skipped in netstack mode")
return false, nil
}
if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name")

View File

@@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool {
return true
}
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
return true // Assume login is required if we can't create auth client
}
defer authClient.Close()
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
log.Errorf("IsLoginRequired: check failed: %v", err)
// If the check fails, assume login is required to be safe
@@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string {
// This could cause a potential race condition with loading the extension which need to be handled on swift side
go func() {
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
if err != nil {
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
return
}
jwtToken := tokenInfo.GetTokenToUse()
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
return
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
log.Errorf("LoginForMobile: Login failed: %v", err)
return
}

View File

@@ -7,13 +7,8 @@ import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
return err
}
return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
@@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
err, _ = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
@@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
// LoginSync performs a synchronous login check without UI interaction
// Used for background VPN connection where user should already be authenticated
func (a *Auth) LoginSync() error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
@@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error {
return fmt.Errorf("not authenticated")
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
if err != nil {
if isAuthError {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
}
return fmt.Errorf("login failed: %v", err)
}
@@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
}
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
var needsLogin bool
// Create context with device name if provided
ctx := a.ctx
if deviceName != "" {
@@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
}
// check if we need to generate JWT token
err := a.withBackOff(ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
return
})
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(ctx, func() error {
err := internal.Login(ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
err, isAuthError := authClient.Login(ctx, "", jwtToken)
if err != nil {
if isAuthError {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
}
return fmt.Errorf("login failed: %v", err)
}
@@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
const authInfoRequestTimeout = 30 * time.Second
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
}
// Use a bounded timeout for the auth info request to prevent indefinite hangs
@@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool)
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
// GetConfigJSON returns the current config as a JSON string.
// This can be used by the caller to persist the config via alternative storage
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).

View File

@@ -0,0 +1,76 @@
package jobexec
import (
"context"
"errors"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
MaxBundleWaitTime = 60 * time.Minute // maximum wait time for bundle generation (1 hour)
)
var (
ErrJobNotImplemented = errors.New("job not implemented")
)
type Executor struct {
}
func NewExecutor() *Executor {
return &Executor{}
}
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, waitForDuration time.Duration, mgmURL string) (string, error) {
if waitForDuration > MaxBundleWaitTime {
log.Warnf("bundle wait time %v exceeds maximum %v, capping to maximum", waitForDuration, MaxBundleWaitTime)
waitForDuration = MaxBundleWaitTime
}
if waitForDuration > 0 {
if err := waitFor(ctx, waitForDuration); err != nil {
return "", err
}
}
log.Infof("execute debug bundle generation")
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
if err != nil {
log.Errorf("failed to upload debug bundle: %v", err)
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle has been generated successfully")
return key, nil
}
func waitFor(ctx context.Context, duration time.Duration) error {
log.Infof("wait for %v minutes before executing debug bundle", duration.Minutes())
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
log.Infof("wait cancelled: %v", ctx.Err())
return ctx.Err()
}
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v3.21.12
// protoc v6.32.1
// source: daemon.proto
package proto
@@ -2757,7 +2757,6 @@ func (x *ForwardingRulesResponse) GetRules() []*ForwardingRule {
type DebugBundleRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
@@ -2802,13 +2801,6 @@ func (x *DebugBundleRequest) GetAnonymize() bool {
return false
}
func (x *DebugBundleRequest) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
func (x *DebugBundleRequest) GetSystemInfo() bool {
if x != nil {
return x.SystemInfo
@@ -5372,6 +5364,154 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
return 0
}
// StartCPUProfileRequest for starting CPU profiling
type StartCPUProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StartCPUProfileRequest) Reset() {
*x = StartCPUProfileRequest{}
mi := &file_daemon_proto_msgTypes[79]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StartCPUProfileRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartCPUProfileRequest) ProtoMessage() {}
func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead.
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
}
// StartCPUProfileResponse confirms CPU profiling has started
type StartCPUProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StartCPUProfileResponse) Reset() {
*x = StartCPUProfileResponse{}
mi := &file_daemon_proto_msgTypes[80]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StartCPUProfileResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartCPUProfileResponse) ProtoMessage() {}
func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead.
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
}
// StopCPUProfileRequest for stopping CPU profiling
type StopCPUProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StopCPUProfileRequest) Reset() {
*x = StopCPUProfileRequest{}
mi := &file_daemon_proto_msgTypes[81]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StopCPUProfileRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StopCPUProfileRequest) ProtoMessage() {}
func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[81]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead.
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{81}
}
// StopCPUProfileResponse confirms CPU profiling has stopped
type StopCPUProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StopCPUProfileResponse) Reset() {
*x = StopCPUProfileResponse{}
mi := &file_daemon_proto_msgTypes[82]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StopCPUProfileResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StopCPUProfileResponse) ProtoMessage() {}
func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[82]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead.
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{82}
}
type InstallerResultRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -5380,7 +5520,7 @@ type InstallerResultRequest struct {
func (x *InstallerResultRequest) Reset() {
*x = InstallerResultRequest{}
mi := &file_daemon_proto_msgTypes[79]
mi := &file_daemon_proto_msgTypes[83]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5392,7 +5532,7 @@ func (x *InstallerResultRequest) String() string {
func (*InstallerResultRequest) ProtoMessage() {}
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
mi := &file_daemon_proto_msgTypes[83]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5405,7 +5545,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
return file_daemon_proto_rawDescGZIP(), []int{83}
}
type InstallerResultResponse struct {
@@ -5418,7 +5558,7 @@ type InstallerResultResponse struct {
func (x *InstallerResultResponse) Reset() {
*x = InstallerResultResponse{}
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[84]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5430,7 +5570,7 @@ func (x *InstallerResultResponse) String() string {
func (*InstallerResultResponse) ProtoMessage() {}
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[84]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5443,7 +5583,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
return file_daemon_proto_rawDescGZIP(), []int{84}
}
func (x *InstallerResultResponse) GetSuccess() bool {
@@ -5470,7 +5610,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[86]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5482,7 +5622,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[86]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5773,10 +5913,9 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x94\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" +
"\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x1e\n" +
"\n" +
"systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" +
@@ -6003,6 +6142,10 @@ const file_daemon_proto_rawDesc = "" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
"\x16StartCPUProfileRequest\"\x19\n" +
"\x17StartCPUProfileResponse\"\x17\n" +
"\x15StopCPUProfileRequest\"\x18\n" +
"\x16StopCPUProfileResponse\"\x18\n" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
@@ -6015,7 +6158,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xb4\x13\n" +
"\x05TRACE\x10\a2\xdd\x14\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6050,7 +6193,9 @@ const file_daemon_proto_rawDesc = "" +
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12T\n" +
"\x0fStartCPUProfile\x12\x1e.daemon.StartCPUProfileRequest\x1a\x1f.daemon.StartCPUProfileResponse\"\x00\x12Q\n" +
"\x0eStopCPUProfile\x12\x1d.daemon.StopCPUProfileRequest\x1a\x1e.daemon.StopCPUProfileResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
@@ -6067,7 +6212,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 88)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
@@ -6152,21 +6297,25 @@ var file_daemon_proto_goTypes = []any{
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
nil, // 85: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
nil, // 87: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
nil, // 89: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 90: daemon.PortInfo.Range
nil, // 91: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 92: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 93: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
92, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
93, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
93, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
92, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
@@ -6177,8 +6326,8 @@ var file_daemon_proto_depIdxs = []int32{
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
89, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
90, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -6189,10 +6338,10 @@ var file_daemon_proto_depIdxs = []int32{
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
93, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
91, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
92, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -6226,43 +6375,47 @@ var file_daemon_proto_depIdxs = []int32{
77, // 62: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
67, // [67:100] is the sub-list for method output_type
34, // [34:67] is the sub-list for method input_type
83, // 65: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
85, // 66: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
5, // 67: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
87, // 68: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 69: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 70: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 71: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 72: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 73: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 74: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 75: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 76: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 77: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 78: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 79: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 80: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 81: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 82: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 83: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 84: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 85: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 86: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 87: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 88: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 89: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 90: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 91: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 92: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 93: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 94: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 95: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 96: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 97: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 98: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 99: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
84, // 100: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
86, // 101: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
6, // 102: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
88, // 103: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
69, // [69:104] is the sub-list for method output_type
34, // [34:69] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
@@ -6292,7 +6445,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
NumMessages: 84,
NumMessages: 88,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -94,6 +94,12 @@ service DaemonService {
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
// StartCPUProfile starts CPU profiling in the daemon
rpc StartCPUProfile(StartCPUProfileRequest) returns (StartCPUProfileResponse) {}
// StopCPUProfile stops CPU profiling in the daemon
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
@@ -455,7 +461,6 @@ message ForwardingRulesResponse {
// DebugBundler
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
bool systemInfo = 3;
string uploadURL = 4;
uint32 logFileCount = 5;
@@ -777,6 +782,18 @@ message WaitJWTTokenResponse {
int64 expiresIn = 3;
}
// StartCPUProfileRequest for starting CPU profiling
message StartCPUProfileRequest {}
// StartCPUProfileResponse confirms CPU profiling has started
message StartCPUProfileResponse {}
// StopCPUProfileRequest for stopping CPU profiling
message StopCPUProfileRequest {}
// StopCPUProfileResponse confirms CPU profiling has stopped
message StopCPUProfileResponse {}
message InstallerResultRequest {
}

View File

@@ -70,6 +70,10 @@ type DaemonServiceClient interface {
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
// StartCPUProfile starts CPU profiling in the daemon
StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error)
// StopCPUProfile stops CPU profiling in the daemon
StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
@@ -384,6 +388,24 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
return out, nil
}
func (c *daemonServiceClient) StartCPUProfile(ctx context.Context, in *StartCPUProfileRequest, opts ...grpc.CallOption) (*StartCPUProfileResponse, error) {
out := new(StartCPUProfileResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StartCPUProfile", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) StopCPUProfile(ctx context.Context, in *StopCPUProfileRequest, opts ...grpc.CallOption) (*StopCPUProfileResponse, error) {
out := new(StopCPUProfileResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/StopCPUProfile", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
out := new(OSLifecycleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
@@ -458,6 +480,10 @@ type DaemonServiceServer interface {
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
// StartCPUProfile starts CPU profiling in the daemon
StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error)
// StopCPUProfile stops CPU profiling in the daemon
StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
@@ -560,6 +586,12 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) StartCPUProfile(context.Context, *StartCPUProfileRequest) (*StartCPUProfileResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StartCPUProfile not implemented")
}
func (UnimplementedDaemonServiceServer) StopCPUProfile(context.Context, *StopCPUProfileRequest) (*StopCPUProfileResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StopCPUProfile not implemented")
}
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
@@ -1140,6 +1172,42 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
return interceptor(ctx, in, info, handler)
}
func _DaemonService_StartCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StartCPUProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).StartCPUProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/StartCPUProfile",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).StartCPUProfile(ctx, req.(*StartCPUProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_StopCPUProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StopCPUProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).StopCPUProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/StopCPUProfile",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).StopCPUProfile(ctx, req.(*StopCPUProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(OSLifecycleRequest)
if err := dec(in); err != nil {
@@ -1303,6 +1371,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
{
MethodName: "StartCPUProfile",
Handler: _DaemonService_StartCPUProfile_Handler,
},
{
MethodName: "StopCPUProfile",
Handler: _DaemonService_StopCPUProfile_Handler,
},
{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,

View File

@@ -3,25 +3,19 @@
package server
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"runtime/pprof"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -32,16 +26,37 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
log.Warnf("failed to get latest sync response: %v", err)
}
var cpuProfileData []byte
if s.cpuProfileBuf != nil && !s.cpuProfiling {
cpuProfileData = s.cpuProfileBuf.Bytes()
defer func() {
s.cpuProfileBuf = nil
}()
}
// Prepare refresh callback for health probes
var refreshStatus func()
if s.connectClient != nil {
engine := s.connectClient.Engine()
if engine != nil {
refreshStatus = func() {
log.Debug("refreshing system health status for debug bundle")
engine.RunHealthProbes(true)
}
}
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: s.config,
StatusRecorder: s.statusRecorder,
SyncResponse: syncResponse,
LogFile: s.logFile,
LogPath: s.logFile,
CPUProfile: cpuProfileData,
RefreshStatus: refreshStatus,
},
debug.BundleConfig{
Anonymize: req.GetAnonymize(),
ClientStatus: req.GetStatus(),
IncludeSystemInfo: req.GetSystemInfo(),
LogFileCount: req.GetLogFileCount(),
},
@@ -55,7 +70,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
if req.GetUploadURL() == "" {
return &proto.DebugBundleResponse{Path: path}, nil
}
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
if err != nil {
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
@@ -66,92 +81,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
}
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}
// GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock()
@@ -173,20 +102,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
log.SetLevel(level)
if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
if s.connectClient != nil {
s.connectClient.SetLogLevel(level)
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}
fwManager.SetLogLevel(level)
log.Infof("Log level set to %s", level.String())
@@ -215,3 +133,43 @@ func (s *Server) getLatestSyncResponse() (*mgmProto.SyncResponse, error) {
return cClient.GetLatestSyncResponse()
}
// StartCPUProfile starts CPU profiling in the daemon.
func (s *Server) StartCPUProfile(_ context.Context, _ *proto.StartCPUProfileRequest) (*proto.StartCPUProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.cpuProfiling {
return nil, fmt.Errorf("CPU profiling already in progress")
}
s.cpuProfileBuf = &bytes.Buffer{}
s.cpuProfiling = true
if err := pprof.StartCPUProfile(s.cpuProfileBuf); err != nil {
s.cpuProfileBuf = nil
s.cpuProfiling = false
return nil, fmt.Errorf("start CPU profile: %w", err)
}
log.Info("CPU profiling started")
return &proto.StartCPUProfileResponse{}, nil
}
// StopCPUProfile stops CPU profiling in the daemon.
func (s *Server) StopCPUProfile(_ context.Context, _ *proto.StopCPUProfileRequest) (*proto.StopCPUProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.cpuProfiling {
return nil, fmt.Errorf("CPU profiling not in progress")
}
pprof.StopCPUProfile()
s.cpuProfiling = false
if s.cpuProfileBuf != nil {
log.Infof("CPU profiling stopped, captured %d bytes", s.cpuProfileBuf.Len())
}
return &proto.StopCPUProfileResponse{}, nil
}

View File

@@ -1,8 +1,6 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
}
}
}
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
events := s.statusRecorder.GetEventHistory()
return &proto.GetEventsResponse{Events: events}, nil
}

View File

@@ -1,6 +1,7 @@
package server
import (
"bytes"
"context"
"errors"
"fmt"
@@ -13,15 +14,11 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -70,7 +67,7 @@ type Server struct {
proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
clientGiveUpChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
connectClient *internal.ConnectClient
@@ -81,6 +78,9 @@ type Server struct {
persistSyncResponse bool
isSessionActive atomic.Bool
cpuProfileBuf *bytes.Buffer
cpuProfiling bool
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
@@ -253,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
var status internal.StatusType
err := internal.Login(ctx, s.config, setupKey, jwtToken)
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
if err != nil {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
log.Errorf("failed to create auth client: %v", err)
return internal.StatusLoginFailed, err
}
defer authClient.Close()
var status internal.StatusType
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
if isAuthError {
log.Warnf("failed login: %v", err)
status = internal.StatusNeedsLogin
} else {
@@ -581,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel()
}
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
waitCTX, cancel := context.WithCancel(ctx)
defer cancel()
s.mutex.Lock()
@@ -796,9 +802,11 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
// Down engine work in the daemon.
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
giveUpChan := s.clientGiveUpChan
if err := s.cleanupConnection(); err != nil {
s.mutex.Unlock()
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
@@ -807,6 +815,20 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
s.mutex.Unlock()
// Wait for the connectWithRetryRuns goroutine to finish with a short timeout.
// This prevents the goroutine from setting ErrResetConnection after Down() returns.
// The giveUpChan is closed at the end of connectWithRetryRuns.
if giveUpChan != nil {
select {
case <-giveUpChan:
log.Debugf("client goroutine finished successfully")
case <-time.After(5 * time.Second):
log.Warnf("timeout waiting for client goroutine to finish, proceeding anyway")
}
}
return &proto.DownResponse{}, nil
}
@@ -1067,11 +1089,9 @@ func (s *Server) Status(
if msg.GetFullPeerStatus {
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1313,6 +1333,10 @@ func (s *Server) runProbes(waitForProbeResult bool) {
if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now()
}
} else {
if err := s.statusRecorder.RefreshWireGuardStats(); err != nil {
log.Debugf("failed to refresh WireGuard stats: %v", err)
}
}
}
@@ -1526,7 +1550,7 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
if err := s.connectClient.Run(runningChan, s.logFile); err != nil {
return err
}
return nil
@@ -1600,94 +1624,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -306,6 +307,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -317,7 +320,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
@@ -326,7 +329,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
}
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
}
defer func() { _ = serverSession.Close() }()
<-session.Context().Done()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
if err := serverSession.Shell(); err != nil {
log.Debugf("start shell: %v", err)
return
}
done := make(chan error, 1)
go func() {
done <- serverSession.Wait()
}()
select {
case <-session.Context().Done():
return
case err := <-done:
if err != nil {
log.Debugf("shell session: %v", err)
p.handleProxyExitCode(session, err)
}
}
}

View File

@@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}

View File

@@ -12,8 +12,8 @@ import (
log "github.com/sirupsen/logrus"
)
// handleCommand executes an SSH command with privilege validation
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
// handleExecution executes an SSH command or shell with privilege validation
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
hasPty := winCh != nil
commandType := "command"
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err)
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType)
}
}
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(session, localUser, hasPty)
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback {
log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
logger.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, cleanup, nil
}
cmd.Env = s.prepareCommandEnv(localUser, session)
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, func() {}, nil
}

View File

@@ -15,17 +15,17 @@ import (
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
// createSuCommand is not supported on JS/WASM
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported
}
// createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported
}
// prepareCommandEnv is not supported on JS/WASM
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
return nil
}

Some files were not shown because too many files have changed in this diff Show More