Compare commits

...

39 Commits

Author SHA1 Message Date
Hakan Sariman
2db23a42dc Add DNS configuration snapshot and per-domain statistics tracking 2025-09-09 17:09:13 +07:00
Hakan Sariman
c2822eebb0 [client] Enhance logging for peer disconnection events 2025-09-09 15:02:16 +07:00
Hakan Sariman
5b246e0a08 debug dns 2025-09-09 14:48:39 +07:00
Zoltan Papp
7aef0f67df [client] Implement environment variable handling for Android (#4440)
Some features can only be manipulated via environment variables. With this PR, environment variables can be managed from Android.
2025-09-08 18:42:42 +02:00
Maycon Santos
dba7ef667d [misc] Remove aur support and start service on ostree (#4461)
* Remove aur support and start service on ostree

The aur installation was adding many packages and installing more than just the client. For now is best to remove it and rely on binary install

Some users complained about ostree installation not starting the client, we add two explicit commands to it

* use  ${SUDO}

* fix if closure
2025-09-08 15:03:56 +02:00
Zoltan Papp
69d87343d2 [client] Debug information for connection (#4439)
Improve logging

Print the exact time when the first WireGuard handshake occurs
Print the steps for gathering system information
2025-09-08 14:51:34 +02:00
Bethuel Mmbaga
5113c70943 [management] Extends integration and peers manager (#4450) 2025-09-06 13:13:49 +03:00
Zoltan Papp
ad8fcda67b [client] Move some sys info to static place (#4446)
This PR refactors the system information collection code by moving static system information gathering to a dedicated location and separating platform-specific implementations. The primary goal is to improve code organization and maintainability by centralizing static info collection logic.

Key changes:
- Centralized static info collection into dedicated files with platform-specific implementations
- Moved `StaticInfo` struct definition to the main static_info.go file
- Added async initialization function `UpdateStaticInfoAsync()` across all platforms
2025-09-06 10:49:28 +02:00
Pascal Fischer
d33f88df82 [management] only allow user devices to be expired (#4445) 2025-09-05 18:11:23 +02:00
Zoltan Papp
786ca6fc79 Do not block Offer processing from relay worker (#4435)
- do not miss ICE offers when relay worker busy
- close p2p connection before recreate agent
2025-09-05 11:02:29 +02:00
Diego Romar
dfebdf1444 [internal] Add missing assignment of iFaceDiscover when netstack is disabled (#4444)
The internal updateInterfaces() function expects iFaceDiscover to not
be nil
2025-09-04 23:00:10 +02:00
Bethuel Mmbaga
a8dcff69c2 [management] Add peers manager to integrations (#4405) 2025-09-04 23:07:03 +03:00
Viktor Liu
71e944fa57 [relay] Let relay accept any origin (#4426) 2025-09-01 19:51:06 +02:00
Maycon Santos
d39fcfd62a [management] Add user approval (#4411)
This PR adds user approval functionality to the management system, allowing administrators to manually approve new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin.

Adds UserApprovalRequired setting to control manual user approval requirement
Introduces user approval and rejection endpoints with corresponding business logic
Prevents pending approval users from adding peers or logging in
2025-09-01 18:00:45 +02:00
Zoltan Papp
21368b38d9 [client] Update Pion ICE to the latest version (#4388)
- Update Pion version
- Update protobuf version
2025-09-01 10:42:01 +02:00
Maycon Santos
d817584f52 [misc] fix Windows client and management bench tests (#4424)
Windows tests had too many directories, causing issues to the payload via psexec.

Also migrated all checked benchmarks to send data to grafana.
2025-08-31 17:19:56 +02:00
Pascal Fischer
4d3dc3475d [management] remove duplicated removal of groups on peer delete (#4421) 2025-08-30 12:47:13 +02:00
Pascal Fischer
6fc50a438f [management] remove withContext from store methods (#4422) 2025-08-30 12:46:54 +02:00
Vlad
149559a06b [management] login filter to fix multiple peers connected with the same pub key (#3986) 2025-08-29 19:48:40 +02:00
Pascal Fischer
e14c6de203 [management] fix ephemeral flag on peer batch response (#4420) 2025-08-29 17:41:20 +02:00
Viktor Liu
d4c067f0af [client] Don't deactivate upstream resolvers on failure (#4128) 2025-08-29 17:40:05 +02:00
Pascal Fischer
dbefa8bd9f [management] remove lock and continue user update on failure (#4410) 2025-08-28 17:50:12 +02:00
Pascal Fischer
4fd10b9447 [management] split high latency grpc metrics (#4408) 2025-08-28 13:25:40 +02:00
Viktor Liu
aa595c3073 [client] Fix shared sock buffer allocation (#4409) 2025-08-28 13:25:16 +02:00
Vlad
99bd34c02a [signal] fix goroutines and memory leak on forward messages between peers (#3896) 2025-08-27 19:30:49 +03:00
Krzysztof Nazarewski (kdn)
7ce5507c05 [client] fix darwin dns always throwing err (#4403)
* fix: dns/host_darwin.go was missing if err != nil before throwing error
2025-08-27 09:59:39 +02:00
Pascal Fischer
0320bb7b35 [management] Report sync duration and login duration by accountID (#4406) 2025-08-26 22:32:12 +02:00
Viktor Liu
f063866ce8 [client] Add flag to configure MTU (#4213) 2025-08-26 16:00:14 +02:00
plusls
9f84165763 [client] Add netstack support for Android cli (#4319) 2025-08-26 15:40:01 +02:00
Pascal Fischer
3488a516c9 [management] Move increment network serial as last step of each transaction (#4397) 2025-08-25 17:27:07 +02:00
Pascal Fischer
5e273c121a [management] Remove store locks 3 (#4390) 2025-08-21 20:47:28 +02:00
Bethuel Mmbaga
968d95698e [management] Bump github.com/golang-jwt/jwt from 3.2.2+incompatible to 5.3.0 (#4375) 2025-08-21 15:02:51 +03:00
Pascal Fischer
28bef26537 [management] Remove Store Locks 2 (#4385) 2025-08-21 12:23:49 +02:00
Pascal Fischer
0d2845ea31 [management] optimize proxy network map (#4324) 2025-08-20 19:04:19 +02:00
Zoltan Papp
f425870c8e [client] Avoid duplicated agent close (#4383) 2025-08-20 18:50:51 +02:00
Pascal Fischer
f9d64a06c2 [management] Remove all store locks from grpc side (#4374) 2025-08-20 12:41:14 +02:00
hakansa
86555c44f7 refactor doc workflow (#4373)
refactor doc workflow (#4373)
2025-08-20 10:59:32 +03:00
Bastien Jeannelle
48792c64cd [misc] Fix confusing comment (#4376) 2025-08-20 00:12:00 +02:00
hakansa
533d93eb17 [management,client] Feat/exit node auto apply (#4272)
[management,client] Feat/exit node auto apply (#4272)
2025-08-19 18:19:24 +03:00
200 changed files with 6344 additions and 2053 deletions

View File

@@ -16,19 +16,29 @@ jobs:
steps: steps:
- name: Read PR body - name: Read PR body
id: body id: body
shell: bash
run: | run: |
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH") set -euo pipefail
echo "body<<EOF" >> $GITHUB_OUTPUT BODY_B64=$(jq -r '.pull_request.body // "" | @base64' "$GITHUB_EVENT_PATH")
echo "$BODY" >> $GITHUB_OUTPUT {
echo "EOF" >> $GITHUB_OUTPUT echo "body_b64=$BODY_B64"
} >> "$GITHUB_OUTPUT"
- name: Validate checkbox selection - name: Validate checkbox selection
id: validate id: validate
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: | run: |
body='${{ steps.body.outputs.body }}' set -euo pipefail
if ! body="$(printf '%s' "$BODY_B64" | base64 -d)"; then
echo "::error::Failed to decode PR body from base64. Data may be corrupted or missing."
exit 1
fi
added_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*I added/updated documentation' | wc -l | tr -d '[:space:]' || true)
noneed_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*Documentation is \*\*not needed\*\*' | wc -l | tr -d '[:space:]' || true)
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'." echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
@@ -41,30 +51,35 @@ jobs:
fi fi
if [ "$added_checked" -eq 1 ]; then if [ "$added_checked" -eq 1 ]; then
echo "mode=added" >> $GITHUB_OUTPUT echo "mode=added" >> "$GITHUB_OUTPUT"
else else
echo "mode=noneed" >> $GITHUB_OUTPUT echo "mode=noneed" >> "$GITHUB_OUTPUT"
fi fi
- name: Extract docs PR URL (when 'docs added') - name: Extract docs PR URL (when 'docs added')
if: steps.validate.outputs.mode == 'added' if: steps.validate.outputs.mode == 'added'
id: extract id: extract
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: | run: |
body='${{ steps.body.outputs.body }}' set -euo pipefail
body="$(printf '%s' "$BODY_B64" | base64 -d)"
# Strictly require HTTPS and that it's a PR in netbirdio/docs # Strictly require HTTPS and that it's a PR in netbirdio/docs
# Examples accepted: # e.g., https://github.com/netbirdio/docs/pull/1234
# https://github.com/netbirdio/docs/pull/1234 url="$(printf '%s' "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)"
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
if [ -z "$url" ]; then if [ -z "${url:-}" ]; then
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)." echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
exit 1 exit 1
fi fi
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#') pr_number="$(printf '%s' "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')"
echo "url=$url" >> $GITHUB_OUTPUT {
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT echo "url=$url"
echo "pr_number=$pr_number"
} >> "$GITHUB_OUTPUT"
- name: Verify docs PR exists (and is open or merged) - name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added' if: steps.validate.outputs.mode == 'added'

View File

@@ -382,6 +382,32 @@ jobs:
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@@ -428,9 +454,10 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \ CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -521,7 +548,7 @@ jobs:
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/server/http/...
api_integration_test: api_integration_test:
name: "Management / Integration" name: "Management / Integration"
@@ -571,4 +598,4 @@ jobs:
CI=true \ CI=true \
go test -tags=integration \ go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... ./shared/management/... -timeout 20m ./management/server/http/...

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -4,6 +4,7 @@ package android
import ( import (
"context" "context"
"os"
"slices" "slices"
"sync" "sync"
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps. // In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() { func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener() c.recorder.RemoveConnectionListener()
} }
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -0,0 +1,32 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -388,12 +388,12 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
} }
func init() { func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle") debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 10, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle") forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 10, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")

View File

@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
} }
// update host's static platform and system information // update host's static platform and system information
system.UpdateStaticInfo() system.UpdateStaticInfoAsync()
configFilePath, err := activeProf.FilePath() configFilePath, err := activeProf.FilePath()
if err != nil { if err != nil {

View File

@@ -39,6 +39,7 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection" enableLazyConnectionFlag = "enable-lazy-connection"
mtuFlag = "mtu"
) )
var ( var (
@@ -72,6 +73,7 @@ var (
anonymizeFlag bool anonymizeFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
lazyConnEnabled bool lazyConnEnabled bool
mtu uint16
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool updateSettingsDisabled bool

View File

@@ -54,6 +54,7 @@ func TestSetFlagsFromEnvVars(t *testing.T) {
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.") cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
cmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec") t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name") t.Setenv("NB_INTERFACE_NAME", "test-name")

View File

@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint log.Info("starting NetBird service") //nolint
// Collect static system and platform information // Collect static system and platform information
system.UpdateStaticInfo() system.UpdateStaticInfoAsync()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer() p.serv = grpc.NewServer()

View File

@@ -9,29 +9,26 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server" sig "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
) )
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
@@ -90,15 +87,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT(). settingsMockManager.EXPECT().

View File

@@ -63,6 +63,7 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
upCmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
@@ -357,6 +358,11 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.WireguardPort = &p req.WireguardPort = &p
} }
if cmd.Flag(mtuFlag).Changed {
m := int64(mtu)
req.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed { if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &networkMonitor req.NetworkMonitor = &networkMonitor
} }
@@ -436,6 +442,13 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.WireguardPort = &p ic.WireguardPort = &p
} }
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
ic.MTU = &mtu
}
if cmd.Flag(networkMonitorFlag).Changed { if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor ic.NetworkMonitor = &networkMonitor
} }
@@ -533,6 +546,14 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.WireguardPort = &wp loginRequest.WireguardPort = &wp
} }
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
m := int64(mtu)
loginRequest.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed { if cmd.Flag(networkMonitorFlag).Changed {
loginRequest.NetworkMonitor = &networkMonitor loginRequest.NetworkMonitor = &networkMonitor
} }

View File

@@ -8,7 +8,7 @@ import (
"runtime" "runtime"
"sync" "sync"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@@ -56,10 +56,11 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder activityRecorder *ActivityRecorder
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -69,6 +70,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
mtu: mtu,
address: address, address: address,
activityRecorder: NewActivityRecorder(), activityRecorder: NewActivityRecorder(),
} }
@@ -80,6 +82,10 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
return ib return ib
} }
func (s *ICEBind) MTU() uint16 {
return s.mtu
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false s.closed = false
s.closedChanMu.Lock() s.closedChanMu.Lock()
@@ -158,6 +164,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address, WGAddress: s.address,
MTU: s.mtu,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {

View File

@@ -8,9 +8,9 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"

View File

@@ -15,9 +15,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -44,6 +45,7 @@ type UniversalUDPMuxParams struct {
Net transport.Net Net transport.Net
FilterFn FilterFn FilterFn FilterFn
WGAddress wgaddr.Address WGAddress wgaddr.Address
MTU uint16
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -84,7 +86,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// just ignore other packets printing an warning message. // just ignore other packets printing an warning message.
// It is a blocking method, consider running in a go routine. // It is a blocking method, consider running in a go routine.
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
buf := make([]byte, 1500) buf := make([]byte, m.params.MTU+bufsize.WGBufferOverhead)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@@ -0,0 +1,9 @@
package bufsize
const (
// WGBufferOverhead represents the additional buffer space needed beyond MTU
// for WireGuard packet encapsulation (WG header + UDP + IP + safety margin)
// Original hardcoded buffers were 1500, default MTU is 1280, so overhead = 220
// TODO: Calculate this properly based on actual protocol overhead instead of using hardcoded difference
WGBufferOverhead = 220
)

View File

@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
} }
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil return time.Unix(sec, 0), nil
} }

View File

@@ -17,6 +17,7 @@ type WGTunDevice interface {
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -21,7 +21,7 @@ type WGTunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool disableDNS bool
@@ -33,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -58,7 +58,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = "" searchDomainsToString = ""
} }
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return nil, err return nil, err
@@ -137,6 +137,10 @@ func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *WGTunDevice) MTU() uint16 {
return t.mtu
}
func (t *WGTunDevice) FilteredDevice() *FilteredDevice { func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }

View File

@@ -21,7 +21,7 @@ type TunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
@@ -30,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -42,7 +42,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
} }
func (t *TunDevice) Create() (WGConfigurer, error) { func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu) tunDevice, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
@@ -111,6 +111,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -22,6 +22,7 @@ type TunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunFd int tunFd int
@@ -31,12 +32,13 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunFd: tunFd, tunFd: tunFd,
} }
@@ -125,6 +127,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement // todo implement
return nil return nil

View File

@@ -24,7 +24,7 @@ type TunKernelDevice struct {
address wgaddr.Address address wgaddr.Address
wgPort int wgPort int
key string key string
mtu int mtu uint16
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
transportNet transport.Net transportNet transport.Net
@@ -36,7 +36,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@@ -66,7 +66,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
// TODO: do a MTU discovery // TODO: do a MTU discovery
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name) log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
if err := link.setMTU(t.mtu); err != nil { if err := link.setMTU(int(t.mtu)); err != nil {
return nil, fmt.Errorf("set mtu: %w", err) return nil, fmt.Errorf("set mtu: %w", err)
} }
@@ -96,7 +96,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter()) rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter(), t.mtu)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -111,6 +111,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address, WGAddress: t.address,
MTU: t.mtu,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
@@ -158,6 +159,10 @@ func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunKernelDevice) MTU() uint16 {
return t.mtu
}
func (t *TunKernelDevice) DeviceName() string { func (t *TunKernelDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -1,6 +1,3 @@
//go:build !android
// +build !android
package device package device
import ( import (
@@ -22,7 +19,7 @@ type TunNetstackDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
listenAddress string listenAddress string
iceBind *bind.ICEBind iceBind *bind.ICEBind
@@ -35,7 +32,7 @@ type TunNetstackDevice struct {
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -47,7 +44,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
} }
} }
func (t *TunNetstackDevice) Create() (WGConfigurer, error) { func (t *TunNetstackDevice) create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
@@ -57,7 +54,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
} }
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create() tunIface, net, err := t.nsTun.Create()
if err != nil { if err != nil {
@@ -125,6 +122,10 @@ func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunNetstackDevice) MTU() uint16 {
return t.mtu
}
func (t *TunNetstackDevice) DeviceName() string { func (t *TunNetstackDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -0,0 +1,7 @@
//go:build android
package device
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create()
}

View File

@@ -0,0 +1,7 @@
//go:build !android
package device
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
return t.create()
}

View File

@@ -20,7 +20,7 @@ type USPDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
@@ -29,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
return &USPDevice{ return &USPDevice{
@@ -44,9 +44,9 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu
func (t *USPDevice) Create() (WGConfigurer, error) { func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu) tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil { if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) log.Debugf("failed to create tun interface (%s, %d): %s", t.name, int(t.mtu), err)
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.filteredDevice = newDeviceFilter(tunIface) t.filteredDevice = newDeviceFilter(tunIface)
@@ -118,6 +118,10 @@ func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *USPDevice) MTU() uint16 {
return t.mtu
}
func (t *USPDevice) DeviceName() string { func (t *USPDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -23,7 +23,7 @@ type TunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
@@ -33,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -59,7 +59,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, err return nil, err
} }
log.Info("create tun interface") log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu) tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, int(t.mtu))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
@@ -144,6 +144,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -15,6 +15,7 @@ type WGTunDevice interface {
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -26,6 +26,8 @@ import (
const ( const (
DefaultMTU = 1280 DefaultMTU = 1280
MinMTU = 576
MaxMTU = 8192
DefaultWgPort = 51820 DefaultWgPort = 51820
WgInterfaceDefault = configurer.WgInterfaceDefault WgInterfaceDefault = configurer.WgInterfaceDefault
) )
@@ -35,6 +37,17 @@ var (
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found") ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
) )
// ValidateMTU validates that MTU is within acceptable range
func ValidateMTU(mtu uint16) error {
if mtu < MinMTU {
return fmt.Errorf("MTU %d below minimum (%d bytes)", mtu, MinMTU)
}
if mtu > MaxMTU {
return fmt.Errorf("MTU %d exceeds maximum supported size (%d bytes)", mtu, MaxMTU)
}
return nil
}
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Free() error Free() error
@@ -45,7 +58,7 @@ type WGIFaceOpts struct {
Address string Address string
WGPort int WGPort int
WGPrivKey string WGPrivKey string
MTU int MTU uint16
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
@@ -82,6 +95,10 @@ func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
func (w *WGIface) MTU() uint16 {
return w.tun.MTU()
}
// ToInterface returns the net.Interface for the Wireguard interface // ToInterface returns the net.Interface for the Wireguard interface
func (r *WGIface) ToInterface() *net.Interface { func (r *WGIface) ToInterface() *net.Interface {
name := r.tun.DeviceName() name := r.tun.DeviceName()

View File

@@ -3,6 +3,7 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -14,7 +15,16 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
if netstack.IsEnabled() {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,

View File

@@ -17,7 +17,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -16,10 +16,10 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true, userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }

View File

@@ -22,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{} wgIFace := &WGIface{}
if netstack.IsEnabled() { if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -31,11 +31,11 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if device.WireGuardModuleIsLoaded() { if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort) wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
return wgIFace, nil return wgIFace, nil
} }
if device.ModuleTunIsLoaded() { if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -14,7 +14,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
@@ -135,7 +136,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}() }()
for { for {
buf := make([]byte, 1500) buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {

View File

@@ -17,6 +17,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -29,6 +30,7 @@ const (
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
mtu uint16
ebpfManager ebpfMgr.Manager ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn turnConnStore map[uint16]net.Conn
@@ -43,10 +45,11 @@ type WGEBPFProxy struct {
} }
// NewWGEBPFProxy create new WGEBPFProxy instance // NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy") log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{ wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu,
ebpfManager: ebpf.GetEbpfManagerInstance(), ebpfManager: ebpf.GetEbpfManagerInstance(),
turnConnStore: make(map[uint16]net.Conn), turnConnStore: make(map[uint16]net.Conn),
} }
@@ -138,7 +141,7 @@ func (p *WGEBPFProxy) Free() error {
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance. // From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
for p.ctx.Err() == nil { for p.ctx.Err() == nil {
if err := p.readAndForwardPacket(buf); err != nil { if err := p.readAndForwardPacket(buf); err != nil {
if p.ctx.Err() != nil { if p.ctx.Err() != nil {

View File

@@ -7,7 +7,7 @@ import (
) )
func TestWGEBPFProxy_connStore(t *testing.T) { func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(1) wgProxy := NewWGEBPFProxy(1, 1280)
p, _ := wgProxy.storeTurnConn(nil) p, _ := wgProxy.storeTurnConn(nil)
if p != 1 { if p != 1 {
@@ -27,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(1) wgProxy := NewWGEBPFProxy(1, 1280)
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535 wgProxy.lastUsedPort = 65535
@@ -43,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(1) wgProxy := NewWGEBPFProxy(1, 1280)
for i := 0; i < 65535; i++ { for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
@@ -103,7 +104,7 @@ func (e *ProxyWrapper) CloseConn() error {
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500) buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.readFromRemote(ctx, buf) n, err := p.readFromRemote(ctx, buf)
if err != nil { if err != nil {

View File

@@ -11,16 +11,18 @@ import (
type KernelFactory struct { type KernelFactory struct {
wgPort int wgPort int
mtu uint16
ebpfProxy *ebpf.WGEBPFProxy ebpfProxy *ebpf.WGEBPFProxy
} }
func NewKernelFactory(wgPort int) *KernelFactory { func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
f := &KernelFactory{ f := &KernelFactory{
wgPort: wgPort, wgPort: wgPort,
mtu: mtu,
} }
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
@@ -33,7 +35,7 @@ func NewKernelFactory(wgPort int) *KernelFactory {
func (w *KernelFactory) GetProxy() Proxy { func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil { if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
} }
return ebpf.NewProxyWrapper(w.ebpfProxy) return ebpf.NewProxyWrapper(w.ebpfProxy)

View File

@@ -9,19 +9,21 @@ import (
// KernelFactory todo: check eBPF support on FreeBSD // KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct { type KernelFactory struct {
wgPort int wgPort int
mtu uint16
} }
func NewKernelFactory(wgPort int) *KernelFactory { func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{ f := &KernelFactory{
wgPort: wgPort, wgPort: wgPort,
mtu: mtu,
} }
return f return f
} }
func (w *KernelFactory) GetProxy() Proxy { func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -16,7 +16,7 @@ func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) t.Fatalf("failed to initialize ebpf proxy: %s", err)
} }

View File

@@ -84,12 +84,12 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}{ }{
{ {
name: "userspace proxy", name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830), proxy: udpProxy.NewWGUDPProxy(51830, 1280),
}, },
} }
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) t.Fatalf("failed to initialize ebpf proxy: %s", err)
} }

View File

@@ -12,12 +12,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
type WGUDPProxy struct { type WGUDPProxy struct {
localWGListenPort int localWGListenPort int
mtu uint16
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
@@ -34,10 +36,11 @@ type WGUDPProxy struct {
} }
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
func NewWGUDPProxy(wgPort int) *WGUDPProxy { func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu,
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
return p return p
@@ -144,7 +147,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
} }
}() }()
buf := make([]byte, 1500) buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
for ctx.Err() == nil { for ctx.Err() == nil {
n, err := p.localConn.Read(buf) n, err := p.localConn.Read(buf)
if err != nil { if err != nil {
@@ -179,7 +182,7 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
} }
}() }()
buf := make([]byte, 1500) buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.remoteConnRead(ctx, buf) n, err := p.remoteConnRead(ctx, buf)
if err != nil { if err != nil {

View File

@@ -3,15 +3,17 @@ package auth
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/client/internal"
"github.com/stretchr/testify/require"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
) )
type mockHTTPClient struct { type mockHTTPClient struct {

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -244,7 +245,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected() c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil {
log.Error(err)
return wrapErr(err)
}
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager) c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 { if len(relayURLs) > 0 {
if token != nil { if token != nil {
@@ -259,14 +268,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
} }
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil {
log.Error(err)
return wrapErr(err)
}
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
@@ -274,11 +275,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil { if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
@@ -444,6 +446,8 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
BlockInbound: config.BlockInbound, BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled, LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -466,6 +470,20 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
return engineConf, nil return engineConf, nil
} }
func selectMTU(localMTU uint16, peerMTU int32) uint16 {
var finalMTU uint16 = iface.DefaultMTU
if localMTU > 0 {
finalMTU = localMTU
} else if peerMTU > 0 {
finalMTU = uint16(peerMTU)
}
// Set global DNS MTU
dns.SetCurrentMTU(finalMTU)
return finalMTU
}
// connectToSignal creates Signal Service client and established a connection // connectToSignal creates Signal Service client and established a connection
func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) { func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
var sigTLSEnabled bool var sigTLSEnabled bool

View File

@@ -315,6 +315,10 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add sync response: %w", err) return fmt.Errorf("add sync response: %w", err)
} }
if err := g.addDNSConfig(); err != nil {
log.Errorf("failed to add DNS config to debug bundle: %v", err)
}
if err := g.addStateFile(); err != nil { if err := g.addStateFile(); err != nil {
log.Errorf("failed to add state file to debug bundle: %v", err) log.Errorf("failed to add state file to debug bundle: %v", err)
} }
@@ -341,6 +345,50 @@ func (g *BundleGenerator) createArchive() error {
return nil return nil
} }
// addDNSConfig writes a dns_config.json snapshot with routed domains and NS group status
func (g *BundleGenerator) addDNSConfig() error {
type nsGroup struct {
ID string `json:"id"`
Servers []string `json:"servers"`
Domains []string `json:"domains"`
Enabled bool `json:"enabled"`
Error string `json:"error,omitempty"`
}
type dnsConfig struct {
Groups []nsGroup `json:"name_server_groups"`
}
if g.statusRecorder == nil {
return nil
}
states := g.statusRecorder.GetDNSStates()
cfg := dnsConfig{Groups: make([]nsGroup, 0, len(states))}
for _, st := range states {
var servers []string
for _, ap := range st.Servers {
servers = append(servers, ap.String())
}
var errStr string
if st.Error != nil {
errStr = st.Error.Error()
}
cfg.Groups = append(cfg.Groups, nsGroup{
ID: st.ID,
Servers: servers,
Domains: st.Domains,
Enabled: st.Enabled,
Error: errStr,
})
}
bs, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("marshal dns config: %w", err)
}
return g.addFileToZip(bytes.NewReader(bs), "dns_config.json")
}
func (g *BundleGenerator) addSystemInfo() { func (g *BundleGenerator) addSystemInfo() {
if err := g.addRoutes(); err != nil { if err := g.addRoutes(); err != nil {
log.Errorf("failed to add routes to debug bundle: %v", err) log.Errorf("failed to add routes to debug bundle: %v", err)

View File

@@ -0,0 +1,201 @@
package config
import (
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
ErrEmptyURL = errors.New("empty URL")
ErrEmptyHost = errors.New("empty host")
ErrIPNotAllowed = errors.New("IP address not allowed")
)
// ServerDomains represents the management server domains extracted from NetBird configuration
type ServerDomains struct {
Signal domain.Domain
Relay []domain.Domain
Flow domain.Domain
Stuns []domain.Domain
Turns []domain.Domain
}
// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration
func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains {
if config == nil {
return ServerDomains{}
}
domains := ServerDomains{}
domains.Signal = extractSignalDomain(config)
domains.Relay = extractRelayDomains(config)
domains.Flow = extractFlowDomain(config)
domains.Stuns = extractStunDomains(config)
domains.Turns = extractTurnDomains(config)
return domains
}
// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses
func ExtractValidDomain(rawURL string) (domain.Domain, error) {
if rawURL == "" {
return "", ErrEmptyURL
}
parsedURL, err := url.Parse(rawURL)
if err == nil {
if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" {
return domain, err
}
}
return extractFromRawString(rawURL)
}
// extractFromParsedURL handles domain extraction from successfully parsed URLs
func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) {
if parsedURL.Hostname() != "" {
return extractDomainFromHost(parsedURL.Hostname())
}
if parsedURL.Opaque == "" || parsedURL.Scheme == "" {
return "", nil
}
// Handle URLs with opaque content (e.g., stun:host:port)
if strings.Contains(parsedURL.Scheme, ".") {
// This is likely "domain.com:port" being parsed as scheme:opaque
reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque
if host, _, err := net.SplitHostPort(reconstructed); err == nil {
return extractDomainFromHost(host)
}
return extractDomainFromHost(parsedURL.Scheme)
}
// Valid scheme with opaque content (e.g., stun:host:port)
host := parsedURL.Opaque
if queryIndex := strings.Index(host, "?"); queryIndex > 0 {
host = host[:queryIndex]
}
if hostOnly, _, err := net.SplitHostPort(host); err == nil {
return extractDomainFromHost(hostOnly)
}
return extractDomainFromHost(host)
}
// extractFromRawString handles domain extraction when URL parsing fails or returns no results
func extractFromRawString(rawURL string) (domain.Domain, error) {
if host, _, err := net.SplitHostPort(rawURL); err == nil {
return extractDomainFromHost(host)
}
return extractDomainFromHost(rawURL)
}
// extractDomainFromHost extracts domain from a host string, filtering out IP addresses
func extractDomainFromHost(host string) (domain.Domain, error) {
if host == "" {
return "", ErrEmptyHost
}
if _, err := netip.ParseAddr(host); err == nil {
return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host)
}
d, err := domain.FromString(host)
if err != nil {
return "", fmt.Errorf("invalid domain: %v", err)
}
return d, nil
}
// extractSingleDomain extracts a single domain from a URL with error logging
func extractSingleDomain(url, serviceType string) domain.Domain {
if url == "" {
return ""
}
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
return ""
}
return d
}
// extractMultipleDomains extracts multiple domains from URLs with error logging
func extractMultipleDomains(urls []string, serviceType string) []domain.Domain {
var domains []domain.Domain
for _, url := range urls {
if url == "" {
continue
}
d, err := ExtractValidDomain(url)
if err != nil {
log.Debugf("Skipping %s: %v", serviceType, err)
continue
}
domains = append(domains, d)
}
return domains
}
// extractSignalDomain extracts the signal domain from NetBird configuration.
func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Signal != nil {
return extractSingleDomain(config.Signal.Uri, "signal")
}
return ""
}
// extractRelayDomains extracts relay server domains from NetBird configuration.
func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
if config.Relay != nil {
return extractMultipleDomains(config.Relay.Urls, "relay")
}
return nil
}
// extractFlowDomain extracts the traffic flow domain from NetBird configuration.
func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain {
if config.Flow != nil {
return extractSingleDomain(config.Flow.Url, "flow")
}
return ""
}
// extractStunDomains extracts STUN server domains from NetBird configuration.
func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, stun := range config.Stuns {
if stun != nil && stun.Uri != "" {
urls = append(urls, stun.Uri)
}
}
return extractMultipleDomains(urls, "STUN")
}
// extractTurnDomains extracts TURN server domains from NetBird configuration.
func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain {
var urls []string
for _, turn := range config.Turns {
if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" {
urls = append(urls, turn.HostConfig.Uri)
}
}
return extractMultipleDomains(urls, "TURN")
}

View File

@@ -0,0 +1,213 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestExtractValidDomain(t *testing.T) {
tests := []struct {
name string
url string
expected string
expectError bool
}{
{
name: "HTTPS URL with port",
url: "https://api.netbird.io:443",
expected: "api.netbird.io",
},
{
name: "HTTP URL without port",
url: "http://signal.example.com",
expected: "signal.example.com",
},
{
name: "Host with port (no scheme)",
url: "signal.netbird.io:443",
expected: "signal.netbird.io",
},
{
name: "STUN URL",
url: "stun:stun.netbird.io:443",
expected: "stun.netbird.io",
},
{
name: "STUN URL with different port",
url: "stun:stun.netbird.io:5555",
expected: "stun.netbird.io",
},
{
name: "TURNS URL with query params",
url: "turns:turn.netbird.io:443?transport=tcp",
expected: "turn.netbird.io",
},
{
name: "TURN URL",
url: "turn:turn.example.com:3478",
expected: "turn.example.com",
},
{
name: "REL URL",
url: "rel://relay.example.com:443",
expected: "relay.example.com",
},
{
name: "RELS URL",
url: "rels://relay.netbird.io:443",
expected: "relay.netbird.io",
},
{
name: "Raw hostname",
url: "example.org",
expected: "example.org",
},
{
name: "IP address should be rejected",
url: "192.168.1.1",
expectError: true,
},
{
name: "IP address with port should be rejected",
url: "192.168.1.1:443",
expectError: true,
},
{
name: "IPv6 address should be rejected",
url: "2001:db8::1",
expectError: true,
},
{
name: "HTTP URL with IPv4 should be rejected",
url: "http://192.168.1.1:8080",
expectError: true,
},
{
name: "HTTPS URL with IPv4 should be rejected",
url: "https://10.0.0.1:443",
expectError: true,
},
{
name: "STUN URL with IPv4 should be rejected",
url: "stun:192.168.1.1:3478",
expectError: true,
},
{
name: "TURN URL with IPv4 should be rejected",
url: "turn:10.0.0.1:3478",
expectError: true,
},
{
name: "TURNS URL with IPv4 should be rejected",
url: "turns:172.16.0.1:5349",
expectError: true,
},
{
name: "HTTP URL with IPv6 should be rejected",
url: "http://[2001:db8::1]:8080",
expectError: true,
},
{
name: "HTTPS URL with IPv6 should be rejected",
url: "https://[::1]:443",
expectError: true,
},
{
name: "STUN URL with IPv6 should be rejected",
url: "stun:[2001:db8::1]:3478",
expectError: true,
},
{
name: "IPv6 with port should be rejected",
url: "[2001:db8::1]:443",
expectError: true,
},
{
name: "Localhost IPv4 should be rejected",
url: "127.0.0.1:8080",
expectError: true,
},
{
name: "Localhost IPv6 should be rejected",
url: "[::1]:443",
expectError: true,
},
{
name: "REL URL with IPv4 should be rejected",
url: "rel://192.168.1.1:443",
expectError: true,
},
{
name: "RELS URL with IPv4 should be rejected",
url: "rels://10.0.0.1:443",
expectError: true,
},
{
name: "Empty URL",
url: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := ExtractValidDomain(tt.url)
if tt.expectError {
assert.Error(t, err, "Expected error for URL: %s", tt.url)
} else {
assert.NoError(t, err, "Unexpected error for URL: %s", tt.url)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url)
}
})
}
}
func TestExtractDomainFromHost(t *testing.T) {
tests := []struct {
name string
host string
expected string
expectError bool
}{
{
name: "Valid domain",
host: "example.com",
expected: "example.com",
},
{
name: "Subdomain",
host: "api.example.com",
expected: "api.example.com",
},
{
name: "IPv4 address",
host: "192.168.1.1",
expectError: true,
},
{
name: "IPv6 address",
host: "2001:db8::1",
expectError: true,
},
{
name: "Empty host",
host: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := extractDomainFromHost(tt.host)
if tt.expectError {
assert.Error(t, err, "Expected error for host: %s", tt.host)
} else {
assert.NoError(t, err, "Unexpected error for host: %s", tt.host)
assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host)
}
})
}
}

View File

@@ -11,11 +11,12 @@ import (
) )
const ( const (
PriorityLocal = 100 PriorityMgmtCache = 150
PriorityDNSRoute = 75 PriorityLocal = 100
PriorityUpstream = 50 PriorityDNSRoute = 75
PriorityDefault = 1 PriorityUpstream = 50
PriorityFallback = -100 PriorityDefault = 1
PriorityFallback = -100
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {
@@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler // If handler wants to continue, try next handler
if chainWriter.shouldContinue { if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler for domain=%s", qname) // Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue continue
} }
return return

View File

@@ -166,9 +166,10 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error { func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true) if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration") log.Errorf("Unable to get system DNS configuration")
return err return fmt.Errorf("recordSystemDNSSettings(): %w", err)
}
} }
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {

View File

@@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool {
// String returns a string representation of the local resolver // String returns a string representation of the local resolver
func (d *Resolver) String() string { func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records)) return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
} }
func (d *Resolver) Stop() {} func (d *Resolver) Stop() {}

View File

@@ -0,0 +1,360 @@
package mgmt
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
const dnsTimeout = 5 * time.Second
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
// String returns a string representation of the resolver.
func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
m.continueToNext(w, r)
return
}
m.mutex.RLock()
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
m.continueToNext(w, r)
return
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write response: %v", err)
}
}
// MatchSubdomains returns false since this resolver only handles exact domain matches
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
func (m *Resolver) MatchSubdomains() bool {
return false
}
// continueToNext signals the handler chain to continue to the next handler.
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write continue signal: %v", err)
}
}
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
}
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
}
m.mutex.Lock()
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {
return nil
}
d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
if err != nil {
return fmt.Errorf("extract domain from URL: %w", err)
}
m.mutex.Lock()
m.mgmtDomain = &d
m.mutex.Unlock()
if err := m.AddDomain(ctx, d); err != nil {
return fmt.Errorf("add domain: %w", err)
}
return nil
}
// RemoveDomain removes a domain from the cache.
func (m *Resolver) RemoveDomain(d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.Lock()
defer m.mutex.Unlock()
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
}
// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
defer m.mutex.RUnlock()
domainSet := make(map[domain.Domain]struct{})
for question := range m.records {
domainName := strings.TrimSuffix(question.Name, ".")
domainSet[domain.Domain(domainName)] = struct{}{}
}
domains := make(domain.List, 0, len(domainSet))
for d := range domainSet {
domains = append(domains, d)
}
return domains
}
// UpdateFromServerDomains updates the cache with server domains from network configuration.
// It merges new domains with existing ones, replacing entire domain types when updated.
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains domain.List
if len(newDomains) > 0 {
m.mutex.Lock()
if m.serverDomains == nil {
m.serverDomains = &dnsconfig.ServerDomains{}
}
updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains)
m.serverDomains = &updatedServerDomains
m.mutex.Unlock()
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
}
m.addNewDomains(ctx, newDomains)
return removedDomains, nil
}
// removeStaleDomains removes cached domains not present in the target domain list.
// Management domains are preserved and never removed during server domain updates.
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
var removedDomains domain.List
for _, currentDomain := range currentDomains {
if m.isDomainInList(currentDomain, newDomains) {
continue
}
if m.isManagementDomain(currentDomain) {
continue
}
removedDomains = append(removedDomains, currentDomain)
if err := m.RemoveDomain(currentDomain); err != nil {
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
}
}
return removedDomains
}
// mergeServerDomains merges new server domains with existing ones.
// When a domain type is provided in the new domains, it completely replaces that type.
func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains {
merged := existing
if incoming.Signal != "" {
merged.Signal = incoming.Signal
}
if len(incoming.Relay) > 0 {
merged.Relay = incoming.Relay
}
if incoming.Flow != "" {
merged.Flow = incoming.Flow
}
if len(incoming.Stuns) > 0 {
merged.Stuns = incoming.Stuns
}
if len(incoming.Turns) > 0 {
merged.Turns = incoming.Turns
}
return merged
}
// isDomainInList checks if domain exists in the list
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
for _, d := range list {
if domain.SafeString() == d.SafeString() {
return true
}
}
return false
}
// isManagementDomain checks if domain is the protected management domain
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains domain.List
if serverDomains.Signal != "" {
domains = append(domains, serverDomains.Signal)
}
for _, relay := range serverDomains.Relay {
if relay != "" {
domains = append(domains, relay)
}
}
if serverDomains.Flow != "" {
domains = append(domains, serverDomains.Flow)
}
for _, stun := range serverDomains.Stuns {
if stun != "" {
domains = append(domains, stun)
}
}
for _, turn := range serverDomains.Turns {
if turn != "" {
domains = append(domains, turn)
}
}
return domains
}

View File

@@ -0,0 +1,416 @@
package mgmt
import (
"context"
"fmt"
"net/url"
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver()
assert.NotNil(t, resolver)
assert.NotNil(t, resolver.records)
assert.False(t, resolver.MatchSubdomains())
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string
urlStr string
expectedDom string
expectError bool
}{
{
name: "HTTPS URL with port",
urlStr: "https://api.netbird.io:443",
expectedDom: "api.netbird.io",
expectError: false,
},
{
name: "HTTP URL without port",
urlStr: "http://signal.example.com",
expectedDom: "signal.example.com",
expectError: false,
},
{
name: "URL with path",
urlStr: "https://relay.netbird.io/status",
expectedDom: "relay.netbird.io",
expectError: false,
},
{
name: "Invalid URL",
urlStr: "not-a-valid-url",
expectedDom: "not-a-valid-url",
expectError: false,
},
{
name: "Empty URL",
urlStr: "",
expectedDom: "",
expectError: true,
},
{
name: "STUN URL",
urlStr: "stun:stun.example.com:3478",
expectedDom: "stun.example.com",
expectError: false,
},
{
name: "TURN URL",
urlStr: "turn:turn.example.com:3478",
expectedDom: "turn.example.com",
expectError: false,
},
{
name: "REL URL",
urlStr: "rel://relay.example.com:443",
expectedDom: "relay.example.com",
expectError: false,
},
{
name: "RELS URL",
urlStr: "rels://relay.example.com:443",
expectedDom: "relay.example.com",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var parsedURL *url.URL
var err error
if tt.urlStr != "" {
parsedURL, err = url.Parse(tt.urlStr)
if err != nil && !tt.expectError {
t.Fatalf("Failed to parse URL: %v", err)
}
}
domain, err := extractDomainFromURL(parsedURL)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedDom, domain.SafeString())
}
})
}
}
func TestResolver_PopulateFromConfig(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := NewResolver()
// Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
assert.Error(t, err)
assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
// No domains should be cached when using IP addresses
domains := resolver.GetCachedDomains()
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
}
func TestResolver_ServeDNS(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Add a test domain to the cache - use example.org which is reserved for testing
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Test A record query for cached domain
t.Run("Cached domain A record", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
})
// Test uncached domain signals to continue to next handler
t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
req := new(dns.Msg)
req.SetQuestion("unknown.example.com.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
// Zero flag set to true signals the handler chain to continue to next handler
assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
})
// Test that subdomains of cached domains are NOT resolved
t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query for a subdomain of our cached domain
req := new(dns.Msg)
req.SetQuestion("sub.example.org.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
})
// Test case-insensitive matching
t.Run("Case-insensitive domain matching", func(t *testing.T) {
var capturedMsg *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
capturedMsg = m
return nil
},
}
// Query with different casing
req := new(dns.Msg)
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
resolver.ServeDNS(mockWriter, req)
assert.NotNil(t, capturedMsg)
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
})
}
func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
testDomain, err := domain.FromString("example.org")
if err != nil {
t.Fatalf("Failed to create domain: %v", err)
}
err = resolver.AddDomain(ctx, testDomain)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
cachedDomains := resolver.GetCachedDomains()
assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
}
func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
mgmtURL, _ := url.Parse("https://example.org")
err := resolver.PopulateFromConfig(ctx, mgmtURL)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
initialDomains := resolver.GetCachedDomains()
if len(initialDomains) == 0 {
t.Skip("Management domain failed to resolve, skipping test")
}
assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
assert.Equal(t, "example.org", initialDomains[0].SafeString())
serverDomains := dnsconfig.ServerDomains{
Signal: "google.com",
Relay: []domain.Domain{"cloudflare.com"},
}
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
if err != nil {
t.Logf("Server domains update failed: %v", err)
}
finalDomains := resolver.GetCachedDomains()
managementStillCached := false
for _, d := range finalDomains {
if d.SafeString() == "example.org" {
managementStillCached = true
break
}
}
assert.True(t, managementStillCached, "Management domain should never be removed")
}
// extractDomainFromURL extracts a domain from a URL - test helper function
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
if u == nil {
return "", fmt.Errorf("URL is nil")
}
return dnsconfig.ExtractValidDomain(u.String())
}
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Verify domains were added
cachedDomains := resolver.GetCachedDomains()
assert.Len(t, cachedDomains, 3)
// Update with empty ServerDomains (simulating partial network map update)
emptyDomains := dnsconfig.ServerDomains{}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
assert.NoError(t, err)
// Verify no domains were removed
assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty")
// Verify all original domains are still cached
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "All original domains should still be cached")
}
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial complete domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
partialDomains := dnsconfig.ServerDomains{
Signal: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
// Should remove only the old signal domain
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
assert.Equal(t, "example.org", removedDomains[0].SafeString())
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
domainStrings[i] = d.SafeString()
}
assert.Contains(t, domainStrings, "github.com")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.NotContains(t, domainStrings, "example.org")
}
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
resolver := NewResolver()
ctx := context.Background()
// Set up initial complete domains using resolvable domains
initialDomains := dnsconfig.ServerDomains{
Signal: "example.org",
Stuns: []domain.Domain{"google.com"},
Turns: []domain.Domain{"cloudflare.com"},
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
partialDomains := dnsconfig.ServerDomains{
Flow: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
finalDomains := resolver.GetCachedDomains()
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
domainStrings := make([]string, len(finalDomains))
for i, d := range finalDomains {
domainStrings[i] = d.SafeString()
}
assert.Contains(t, domainStrings, "example.org")
assert.Contains(t, domainStrings, "google.com")
assert.Contains(t, domainStrings, "cloudflare.com")
assert.Contains(t, domainStrings, "github.com")
}

View File

@@ -3,20 +3,23 @@ package dns
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"github.com/miekg/dns" "github.com/miekg/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
// MockServer is the mock instance of a dns server // MockServer is the mock instance of a dns server
type MockServer struct { type MockServer struct {
InitializeFunc func() error InitializeFunc func() error
StopFunc func() StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func(domain.List, dns.Handler, int) RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int) DeregisterHandlerFunc func(domain.List, int)
UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error
} }
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string {
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() { func (m *MockServer) ProbeAvailability() {
} }
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
}
return nil
}
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@@ -15,7 +16,9 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/local" "github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/mgmt"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -45,6 +48,8 @@ type Server interface {
OnUpdatedHostDNSServer(addrs []netip.AddrPort) OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
} }
type nsGroupsByDomain struct { type nsGroupsByDomain struct {
@@ -77,6 +82,8 @@ type DefaultServer struct {
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int extraDomains map[domain.Domain]int
mgmtCacheResolver *mgmt.Resolver
// permanent related properties // permanent related properties
permanent bool permanent bool
hostsDNSHolder *hostsDNSHolder hostsDNSHolder *hostsDNSHolder
@@ -104,18 +111,20 @@ type handlerWrapper struct {
type registeredHandlerMap map[types.HandlerID]handlerWrapper type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
CustomAddress string
StatusRecorder *peer.Status
StateManager *statemanager.Manager
DisableSys bool
}
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) {
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if config.CustomAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err)
} }
@@ -123,13 +132,14 @@ func NewDefaultServer(
} }
var dnsService service var dnsService service
if wgInterface.IsUserspaceBind() { if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface) dnsService = NewServiceViaMemory(config.WgInterface)
} else { } else {
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(config.WgInterface, addrPort)
} }
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
return server, nil
} }
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -178,20 +188,24 @@ func newDefaultServer(
) *DefaultServer { ) *DefaultServer {
handlerChain := NewHandlerChain() handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
ctxCancel: stop, ctxCancel: stop,
disableSys: disableSys, disableSys: disableSys,
service: dnsService, service: dnsService,
handlerChain: handlerChain, handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int), extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap), dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(), localResolver: local.NewResolver(),
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager, stateManager: stateManager,
hostsDNSHolder: newHostsDNSHolder(), hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
} }
// register with root zone, handler chain takes care of the routing // register with root zone, handler chain takes care of the routing
@@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority) log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains)
for _, domain := range domains { for _, domain := range domains {
if domain == "" { if domain == "" {
@@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
} }
func (s *DefaultServer) deregisterHandler(domains []string, priority int) { func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority) log.Debugf("deregistering handler with priority %d for %v", priority, domains)
for _, domain := range domains { for _, domain := range domains {
if domain == "" { if domain == "" {
@@ -432,6 +446,29 @@ func (s *DefaultServer) ProbeAvailability() {
wg.Wait() wg.Wait()
} }
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
s.mux.Lock()
defer s.mux.Unlock()
if s.mgmtCacheResolver != nil {
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
if err != nil {
return fmt.Errorf("update management cache resolver: %w", err)
}
if len(removedDomains) > 0 {
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
if len(newDomains) > 0 {
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
}
}
return nil
}
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver // is the service should be Disabled, we stop the listener or fake resolver
if update.ServiceEnable { if update.ServiceEnable {
@@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain {
), ),
) )
} }
// PopulateManagementDomain populates the DNS cache with management domain
func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
if s.mgmtCacheResolver != nil {
return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL)
}
return nil
}

View File

@@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: &mocWGIface{},
CustomAddress: testCase.addrPort,
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }

View File

@@ -26,10 +26,18 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
const ( var currentMTU uint16 = iface.DefaultMTU
UpstreamTimeout = 15 * time.Second
func SetCurrentMTU(mtu uint16) {
currentMTU = mtu
}
const (
UpstreamTimeout = 4 * time.Second
// ClientTimeout is the timeout for the dns.Client.
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
failsTillDeact = int32(5)
reactivatePeriod = 30 * time.Second reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second probeTimeout = 2 * time.Second
) )
@@ -52,9 +60,7 @@ type upstreamResolverBase struct {
upstreamServers []netip.AddrPort upstreamServers []netip.AddrPort
domain string domain string
disabled bool disabled bool
failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex mutex sync.Mutex
reactivatePeriod time.Duration reactivatePeriod time.Duration
upstreamTimeout time.Duration upstreamTimeout time.Duration
@@ -73,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
domain: domain, domain: domain,
upstreamTimeout: UpstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
} }
} }
// String returns a string representation of the upstream resolver // String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string { func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %s", u.upstreamServers) return fmt.Sprintf("Upstream %s", u.upstreamServers)
} }
// ID returns the unique handler ID // ID returns the unique handler ID
@@ -110,58 +115,102 @@ func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID() requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID) logger := log.WithField("request_id", requestID)
var err error
defer func() {
u.checkUpstreamFails(err)
}()
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
if u.ctx.Err() != nil {
logger.Tracef("%s has been stopped", u)
return
}
if u.tryUpstreamServers(w, r, logger) {
return
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
}
select { func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
case <-u.ctx.Done(): timeout := u.upstreamTimeout
logger.Tracef("%s has been stopped", u) if len(u.upstreamServers) > 1 {
return maxTotal := 5 * time.Second
default: minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
var rm *dns.Msg if u.queryUpstream(w, r, upstream, timeout, logger) {
var t time.Duration return true
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
} }
}
return false
}
if rm == nil || !rm.Response { func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) var rm *dns.Msg
continue var t time.Duration
} var err error
u.successCount.Add(1) var startTime time.Time
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err = w.WriteMsg(rm); err != nil { if err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
} return false
// count the fails only if they happen sequentially }
u.failsCount.Store(0)
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
}
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return return
} }
u.failsCount.Add(1)
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
@@ -171,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// checkUpstreamFails counts fails and disables or enables upstream resolving
//
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
u.mutex.Lock()
defer u.mutex.Unlock()
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
return
}
select {
case <-u.ctx.Done():
return
default:
}
u.disable(err)
if u.statusRecorder == nil {
return
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta
)
}
// ProbeAvailability tests all upstream servers simultaneously and // ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work // disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() { func (u *upstreamResolverBase) ProbeAvailability() {
@@ -218,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() {
default: default:
} }
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact // avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { if u.successCount.Load() > 0 {
return return
} }
@@ -306,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
u.disabled = false u.disabled = false
@@ -358,8 +371,8 @@ func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout tim
// If the passed context is nil, this will use Exchange instead of ExchangeContext. // If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers // MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower. // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = iface.DefaultMTU - (60 + 8) client.UDPSize = uint16(currentMTU - (60 + 8))
var ( var (
rm *dns.Msg rm *dns.Msg
@@ -410,3 +423,80 @@ func GenerateRequestID() string {
} }
return hex.EncodeToString(bytes) return hex.EncodeToString(bytes)
} }
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
switch {
case !isConnected:
statusInfo += " DISCONNECTED"
case !hasRecentHandshake:
statusInfo += " NO_RECENT_HANDSHAKE"
default:
statusInfo += " connected"
}
if !peerState.LastWireguardHandshake.IsZero() {
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
} else {
statusInfo += " no_handshake"
}
if peerState.Relayed {
statusInfo += " via_relay"
}
if peerState.Latency > 0 {
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
}
return statusInfo
}
// findPeerForIP finds which peer handles the given IP address
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
if statusRecorder == nil {
return nil
}
fullStatus := statusRecorder.GetFullStatus()
var bestMatch *peer.State
var bestPrefixLen int
for _, peerState := range fullStatus.Peers {
routes := peerState.GetRoutes()
for route := range routes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
continue
}
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
peerStateCopy := peerState
bestMatch = &peerStateCopy
bestPrefixLen = prefix.Bits()
}
}
}
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}

View File

@@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns
} }
func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{} upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }
@@ -72,10 +74,11 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
} }
upstreamExchangeClient := &dns.Client{ upstreamExchangeClient := &dns.Client{
Dialer: dialer, Dialer: dialer,
Timeout: timeout,
} }
return upstreamExchangeClient.Exchange(r, upstream) return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
} }
func (u *upstreamResolver) isLocalResolver(upstream string) bool { func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -34,7 +34,10 @@ func newUpstreamResolver(
} }
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) client := &dns.Client{
Timeout: ClientTimeout,
}
return ExchangeWithFallback(ctx, client, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {

View File

@@ -47,7 +47,9 @@ func newUpstreamResolver(
} }
func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
client := &dns.Client{} client := &dns.Client{
Timeout: ClientTimeout,
}
upstreamHost, _, err := net.SplitHostPort(upstream) upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
@@ -110,7 +112,8 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura
}, },
} }
client := &dns.Client{ client := &dns.Client{
Dialer: dialer, Dialer: dialer,
Timeout: dialTimeout,
} }
return client, nil return client, nil
} }

View File

@@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg)
} }
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
r: new(dns.Msg),
rtt: time.Millisecond,
}
resolver := &upstreamResolverBase{ resolver := &upstreamResolverBase{
ctx: context.TODO(), ctx: context.TODO(),
upstreamClient: &mockUpstreamResolver{ upstreamClient: mockClient,
err: nil,
r: new(dns.Msg),
rtt: time.Millisecond,
},
upstreamTimeout: UpstreamTimeout, upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod, reactivatePeriod: time.Microsecond * 100,
failsTillDeact: failsTillDeact,
} }
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil },
}
failed := false failed := false
resolver.deactivate = func(error) { resolver.deactivate = func(error) {
failed = true failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
} }
reactivated := false reactivated := false
@@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true reactivated = true
} }
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) resolver.ProbeAvailability()
if !failed { if !failed {
t.Errorf("expected that resolving was deactivated") t.Errorf("expected that resolving was deactivated")
@@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
return return
} }
if resolver.failsCount.Load() != 0 {
t.Errorf("fails count after reactivation should be 0")
return
}
if resolver.disabled { if resolver.disabled {
t.Errorf("should be enabled") t.Errorf("should be enabled")
} }

View File

@@ -46,6 +46,18 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewaller firewall firewaller
resolver resolver resolver resolver
// failure rate tracking for routed domains
failureMu sync.Mutex
failureCounts map[string]int
failureWindow time.Duration
lastLogPerHost map[string]time.Time
// per-domain rolling stats and windows
statsMu sync.Mutex
stats map[string]*domainStats
winSize time.Duration
slowT time.Duration
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -56,9 +68,25 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
failureCounts: make(map[string]int),
failureWindow: 10 * time.Second,
lastLogPerHost: make(map[string]time.Time),
stats: make(map[string]*domainStats),
winSize: 10 * time.Second,
slowT: 300 * time.Millisecond,
} }
} }
type domainStats struct {
total int
success int
timeouts int
notfound int
failures int // other failures (incl. SERVFAIL-like)
slow int
lastLog time.Time
}
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder on address=%s", f.listenAddress) log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
@@ -163,12 +191,19 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
start := time.Now()
ips, err := f.resolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
elapsed := time.Since(start)
if err != nil { if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err) f.handleDNSError(ctx, w, question, resp, domain, err)
// record error stats for routed domains
f.recordErrorStats(strings.TrimSuffix(domain, "."), err)
return nil return nil
} }
// record success timing
f.recordSuccessStats(strings.TrimSuffix(domain, "."), elapsed)
f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
@@ -306,6 +341,91 @@ func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter,
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err) log.Errorf("failed to write failure DNS response: %v", err)
} }
// Track failure rate for routed domains only
if resID, _ := f.getMatchingEntries(strings.TrimSuffix(domain, ".")); resID != "" {
f.recordDomainFailure(strings.TrimSuffix(domain, "."))
}
}
// recordErrorStats updates per-domain counters and emits rate-limited logs
func (f *DNSForwarder) recordErrorStats(domain string, err error) {
domain = strings.ToLower(domain)
f.statsMu.Lock()
s := f.ensureStats(domain)
s.total++
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) {
if dnsErr.IsNotFound {
s.notfound++
} else if dnsErr.Timeout() {
s.timeouts++
} else {
s.failures++
}
} else {
s.failures++
}
f.maybeLogDomainStats(domain, s)
f.statsMu.Unlock()
}
// recordSuccessStats updates per-domain latency stats and slow counters, logs if needed (rate-limited)
func (f *DNSForwarder) recordSuccessStats(domain string, elapsed time.Duration) {
domain = strings.ToLower(domain)
f.statsMu.Lock()
s := f.ensureStats(domain)
s.total++
s.success++
if elapsed >= f.slowT {
s.slow++
}
f.maybeLogDomainStats(domain, s)
f.statsMu.Unlock()
}
func (f *DNSForwarder) ensureStats(domain string) *domainStats {
if ds, ok := f.stats[domain]; ok {
return ds
}
ds := &domainStats{}
f.stats[domain] = ds
return ds
}
// maybeLogDomainStats logs a compact summary per routed domain at most once per window
func (f *DNSForwarder) maybeLogDomainStats(domain string, s *domainStats) {
now := time.Now()
if !s.lastLog.IsZero() && now.Sub(s.lastLog) < f.winSize {
return
}
// check if routed (avoid logging for non-routed domains)
if resID, _ := f.getMatchingEntries(domain); resID == "" {
return
}
// only log if something noteworthy happened in the window
noteworthy := s.timeouts > 0 || s.notfound > 0 || s.failures > 0 || s.slow > 0
if !noteworthy {
s.lastLog = now
return
}
// warn on persistent problems, info otherwise
levelWarn := s.timeouts >= 3 || s.failures >= 3
if levelWarn {
log.Warnf("[d] DNS stats: domain=%s total=%d ok=%d timeout=%d nxdomain=%d fail=%d slow=%d(>=%s)",
domain, s.total, s.success, s.timeouts, s.notfound, s.failures, s.slow, f.slowT)
} else {
log.Infof("[d] DNS stats: domain=%s total=%d ok=%d timeout=%d nxdomain=%d fail=%d slow=%d(>=%s)",
domain, s.total, s.success, s.timeouts, s.notfound, s.failures, s.slow, f.slowT)
}
// reset counters for next window
*s = domainStats{lastLog: now}
} }
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records // addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
@@ -341,6 +461,27 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
} }
} }
// recordDomainFailure increments failure count for the domain and logs at info/warn with throttling.
func (f *DNSForwarder) recordDomainFailure(domain string) {
domain = strings.ToLower(domain)
f.failureMu.Lock()
defer f.failureMu.Unlock()
f.failureCounts[domain]++
count := f.failureCounts[domain]
now := time.Now()
last, ok := f.lastLogPerHost[domain]
if ok && now.Sub(last) < f.failureWindow {
return
}
f.lastLogPerHost[domain] = now
log.Warnf("[d] DNS failures observed for routed domain: domain=%s failures=%d/%s", domain, count, f.failureWindow)
}
// getMatchingEntries retrieves the resource IDs for a given domain. // getMatchingEntries retrieves the resource IDs for a given domain.
// It returns the most specific match and all matching resource IDs. // It returns the most specific match and all matching resource IDs.
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) { func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {

View File

@@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
"net/url"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
@@ -17,8 +18,8 @@ import (
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -33,6 +34,7 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -125,6 +127,8 @@ type EngineConfig struct {
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
MTU uint16
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -343,10 +347,14 @@ func (e *Engine) Stop() error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start() error { func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if err := iface.ValidateMTU(e.config.MTU); err != nil {
return fmt.Errorf("invalid MTU configuration: %w", err)
}
if e.cancel != nil { if e.cancel != nil {
e.cancel() e.cancel()
} }
@@ -395,6 +403,11 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
// Populate DNS cache with NetbirdConfig and management URL for early resolution
if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache: %v", err)
}
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: e.ctx, Context: e.ctx,
PublicKey: e.config.WgPrivateKey.PublicKey().String(), PublicKey: e.config.WgPrivateKey.PublicKey().String(),
@@ -655,6 +668,30 @@ func (e *Engine) removePeer(peerKey string) error {
return nil return nil
} }
// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response
func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error {
if e.dnsServer == nil {
return nil
}
// Populate management URL if provided
if mgmtURL != nil {
if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil {
log.Warnf("failed to populate DNS cache with management URL: %v", err)
}
}
// Populate NetbirdConfig domains if provided
if netbirdConfig != nil {
serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig)
if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil {
return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err)
}
}
return nil
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -686,6 +723,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("handle the flow configuration: %w", err) return fmt.Errorf("handle the flow configuration: %w", err)
} }
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal // todo update signal
} }
@@ -908,7 +949,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.LazyConnectionEnabled, e.config.LazyConnectionEnabled,
) )
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
@@ -919,7 +959,7 @@ func (e *Engine) receiveManagementEvents() {
} }
log.Debugf("stopped receiving updates from Management Service") log.Debugf("stopped receiving updates from Management Service")
}() }()
log.Debugf("connecting to Management Service updates stream") log.Infof("connecting to Management Service updates stream")
} }
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
@@ -1111,15 +1151,16 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
} }
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix.Masked(), Network: prefix.Masked(),
Domains: domain.FromPunycodeList(protoRoute.Domains), Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer, Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric), Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade, Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute, KeepRoute: protoRoute.KeepRoute,
SkipAutoApply: protoRoute.SkipAutoApply,
} }
routes = append(routes, convertedRoute) routes = append(routes, convertedRoute)
} }
@@ -1491,7 +1532,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
Address: e.config.WgAddr, Address: e.config.WgAddr,
WGPort: e.config.WgPort, WGPort: e.config.WgPort,
WGPrivKey: e.config.WgPrivateKey.String(), WGPrivKey: e.config.WgPrivateKey.String(),
MTU: iface.DefaultMTU, MTU: e.config.MTU,
TransportNet: transportNet, TransportNet: transportNet,
FilterFn: e.addrViaRoutes, FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS, DisableDNS: e.config.DisableDNS,
@@ -1550,7 +1591,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil return dnsServer, nil
default: default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{
WgInterface: e.wgInterface,
CustomAddress: e.config.CustomDNSAddress,
StatusRecorder: e.statusRecorder,
StateManager: e.stateManager,
DisableSys: e.config.DisableDNS,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -19,17 +19,13 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
@@ -45,9 +41,12 @@ import (
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -218,7 +217,7 @@ func TestEngine_SSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine( engine := NewEngine(
ctx, cancel, ctx, cancel,
&signal.MockClient{}, &signal.MockClient{},
@@ -230,6 +229,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
ServerSSHAllowed: true, ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
}, },
MobileDependency{}, MobileDependency{},
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
@@ -265,7 +265,7 @@ func TestEngine_SSH(t *testing.T) {
}, },
}, nil }, nil
} }
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -363,7 +363,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine( engine := NewEngine(
ctx, cancel, ctx, cancel,
&signal.MockClient{}, &signal.MockClient{},
@@ -374,6 +374,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, },
MobileDependency{}, MobileDependency{},
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
@@ -412,7 +413,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
@@ -589,12 +590,13 @@ func TestEngine_Sync(t *testing.T) {
} }
return nil return nil
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103", WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
@@ -609,7 +611,7 @@ func TestEngine_Sync(t *testing.T) {
} }
}() }()
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -753,12 +755,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName, WgIfaceName: wgIfaceName,
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
@@ -954,12 +957,13 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName, WgIfaceName: wgIfaceName,
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
@@ -1064,7 +1068,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
defer mu.Unlock() defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String()) guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid) device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start() err = engine.Start(nil, nil)
if err != nil { if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err) t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done() wg.Done()
@@ -1181,6 +1185,7 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
config: &EngineConfig{ config: &EngineConfig{
IFaceBlackList: testCase.inputBlacklistInterface, IFaceBlackList: testCase.inputBlacklistInterface,
NATExternalIPs: testCase.inputMapList, NATExternalIPs: testCase.inputMapList,
MTU: iface.DefaultMTU,
}, },
} }
parsedList := engine.parseNATExternalIPMappings() parsedList := engine.parseNATExternalIPMappings()
@@ -1481,9 +1486,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgAddr: resp.PeerConfig.Address, WgAddr: resp.PeerConfig.Address,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: wgPort, WgPort: wgPort,
MTU: iface.DefaultMTU,
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx e.ctx = ctx
return e, err return e, err
@@ -1548,7 +1554,11 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
@@ -1565,7 +1575,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
Return(&types.ExtraSettings{}, nil). Return(&types.ExtraSettings{}, nil).
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)

View File

@@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool,
return false, err return false, err
} }
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) { if isLoginNeeded(err) {
return true, nil return true, nil
} }
@@ -69,14 +69,18 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string,
return err return err
} }
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) { if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required") log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err return err
} }
return err return nil
} }
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
@@ -101,11 +105,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
return mgmClient, err return mgmClient, err
} }
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey() serverKey, err := mgmClient.GetServerPublicKey()
if err != nil { if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err) log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err return nil, nil, err
} }
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
@@ -121,8 +125,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.BlockInbound, config.BlockInbound,
config.LazyConnectionEnabled, config.LazyConnectionEnabled,
) )
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, err return serverKey, loginResp, err
} }
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.

View File

@@ -6,12 +6,11 @@ import (
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
"os"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" { if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
} }

View File

@@ -0,0 +1,14 @@
package peer
import (
"os"
"strings"
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func isForceRelayed() bool {
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}

View File

@@ -6,7 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"

View File

@@ -43,13 +43,6 @@ type OfferAnswer struct {
SessionID *ICESessionID SessionID *ICESessionID
} }
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
}
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
@@ -57,7 +50,7 @@ type Handshaker struct {
signaler *Signaler signaler *Signaler
ice *WorkerICE ice *WorkerICE
relay *WorkerRelay relay *WorkerRelay
onNewOfferListeners []func(*OfferAnswer) onNewOfferListeners []*OfferListener
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
@@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
} }
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.onNewOfferListeners = append(h.onNewOfferListeners, offer) l := NewOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
} }
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
@@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue continue
} }
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer) listener.Notify(&remoteOfferAnswer)
} }
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer) listener.Notify(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers") h.log.Infof("stop listening for remote offers and answers")

View File

@@ -0,0 +1,62 @@
package peer
import (
"sync"
)
type callbackFunc func(remoteOfferAnswer *OfferAnswer)
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
}
type OfferListener struct {
fn callbackFunc
running bool
latest *OfferAnswer
mu sync.Mutex
}
func NewOfferListener(fn callbackFunc) *OfferListener {
return &OfferListener{
fn: fn,
}
}
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock()
defer o.mu.Unlock()
// Store the latest offer
o.latest = remoteOfferAnswer
// If already running, the running goroutine will pick up this latest value
if o.running {
return
}
// Start processing
o.running = true
// Process in a goroutine to avoid blocking the caller
go func(remoteOfferAnswer *OfferAnswer) {
for {
o.fn(remoteOfferAnswer)
o.mu.Lock()
if o.latest == nil {
// No more work to do
o.running = false
o.mu.Unlock()
return
}
remoteOfferAnswer = o.latest
// Clear the latest to mark it as being processed
o.latest = nil
o.mu.Unlock()
}
}(remoteOfferAnswer)
}

View File

@@ -0,0 +1,39 @@
package peer
import (
"testing"
"time"
)
func Test_newOfferListener(t *testing.T) {
dummyOfferAnswer := &OfferAnswer{}
runChan := make(chan struct{}, 10)
longRunningFn := func(remoteOfferAnswer *OfferAnswer) {
time.Sleep(1 * time.Second)
runChan <- struct{}{}
}
hl := NewOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)
// Wait for exactly 2 callbacks
for i := 0; i < 2; i++ {
select {
case <-runChan:
case <-time.After(3 * time.Second):
t.Fatal("Timeout waiting for callback")
}
}
// Verify no additional callbacks happen
select {
case <-runChan:
t.Fatal("Unexpected additional callback")
case <-time.After(100 * time.Millisecond):
t.Log("Correctly received exactly 2 callbacks")
}
}

View File

@@ -3,7 +3,7 @@ package ice
import ( import (
"sync/atomic" "sync/atomic"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
) )
type StunTurn atomic.Value type StunTurn atomic.Value

View File

@@ -1,9 +1,10 @@
package ice package ice
import ( import (
"sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/randutil" "github.com/pion/randutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -23,7 +24,20 @@ const (
iceRelayAcceptanceMinWaitDefault = 2 * time.Second iceRelayAcceptanceMinWaitDefault = 2 * time.Second
) )
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { type ThreadSafeAgent struct {
*ice.Agent
once sync.Once
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
err = a.Agent.Close()
})
return err
}
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive() iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout() iceDisconnectedTimeout := iceDisconnectedTimeout()
iceFailedTimeout := iceFailedTimeout() iceFailedTimeout := iceFailedTimeout()
@@ -61,7 +75,12 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
} }
return ice.NewAgent(agentConfig) agent, err := ice.NewAgent(agentConfig)
if err != nil {
return nil, err
}
return &ThreadSafeAgent{Agent: agent}, nil
} }
func GenerateICECredentials() (string, string, error) { func GenerateICECredentials() (string, string, error) {

View File

@@ -1,7 +1,7 @@
package ice package ice
import ( import (
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
) )
type Config struct { type Config struct {

View File

@@ -1,7 +1,7 @@
package peer package peer
import ( import (
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
) )
const eventQueueSize = 10 const eventQueueSize = 10
@@ -201,6 +201,8 @@ type Status struct {
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
lazyConnectionEnabled bool lazyConnectionEnabled bool
lastDisconnectLog map[string]time.Time
// To reduce the number of notification invocation this bool will be true when need to call the notification // To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications() // set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
@@ -229,6 +231,7 @@ func NewRecorder(mgmAddress string) *Status {
notifier: newNotifier(), notifier: newNotifier(),
mgmAddress: mgmAddress, mgmAddress: mgmAddress,
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{}, resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
lastDisconnectLog: make(map[string]time.Time),
} }
} }
@@ -487,6 +490,9 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
// info log about disconnect with impacted routes (throttled)
d.logPeerDisconnectIfNeeded(receivedState.PubKey, peerState)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) { if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged() d.notifyPeerListChanged()
} }
@@ -519,6 +525,9 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
// info log about disconnect with impacted routes (throttled)
d.logPeerDisconnectIfNeeded(receivedState.PubKey, peerState)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) { if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged() d.notifyPeerListChanged()
} }
@@ -529,6 +538,49 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil return nil
} }
// logPeerDisconnectIfNeeded logs an info message when a routing peer transitions to disconnected
// with the number of impacted routes. Throttled to once per peer per 30 seconds.
func (d *Status) logPeerDisconnectIfNeeded(pubKey string, state State) {
if state.ConnStatus != StatusIdle {
return
}
now := time.Now()
last, ok := d.lastDisconnectLog[pubKey]
if ok && now.Sub(last) < 10*time.Second {
return
}
d.lastDisconnectLog[pubKey] = now
routes := state.GetRoutes()
numRoutes := len(routes)
fqdn := state.FQDN
if fqdn == "" {
fqdn = pubKey
}
// prepare a bounded list of impacted routes to avoid huge log lines
maxList := 20
list := make([]string, 0, maxList)
for r := range routes {
if len(list) >= maxList {
break
}
list = append(list, r)
}
more := ""
if numRoutes > len(list) {
more = ", more=" + fmt.Sprintf("%d", numRoutes-len(list))
}
if len(list) > 0 {
log.Warnf("[d] Routing peer disconnected: peer=%s impacted_routes=%d routes=%v%s", fqdn, numRoutes, list, more)
} else {
log.Warnf("[d] Routing peer disconnected: peer=%s impacted_routes=%d", fqdn, numRoutes)
}
}
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state // UpdateWireGuardPeerState updates the WireGuard bits of the peer state
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error { func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
d.mux.Lock() d.mux.Lock()

View File

@@ -30,9 +30,10 @@ type WGWatcher struct {
peerKey string peerKey string
stateDump *stateDump stateDump *stateDump
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxLock sync.Mutex ctxLock sync.Mutex
enabledTime time.Time
} }
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher") w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock() w.ctxLock.Lock()
w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil { if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled") w.log.Errorf("WireGuard watcher already enabled")
@@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
onDisconnectedFn() onDisconnectedFn()
return return
} }
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
lastHandshake = *handshake lastHandshake = *handshake
resetTime := time.Until(handshake.Add(checkPeriod)) resetTime := time.Until(handshake.Add(checkPeriod))

View File

@@ -8,7 +8,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v4"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -42,7 +42,7 @@ type WorkerICE struct {
statusRecorder *Status statusRecorder *Status
hasRelayOnLocally bool hasRelayOnLocally bool
agent *ice.Agent agent *icemaker.ThreadSafeAgent
agentDialerCancel context.CancelFunc agentDialerCancel context.CancelFunc
agentConnecting bool // while it is true, drop all incoming offers agentConnecting bool // while it is true, drop all incoming offers
lastSuccess time.Time // with this avoid the too frequent ICE agent recreation lastSuccess time.Time // with this avoid the too frequent ICE agent recreation
@@ -121,7 +121,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if err := w.agent.Close(); err != nil { if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
// todo consider to switch to Relay connection while establishing a new ICE connection w.agent = nil
} }
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
@@ -195,7 +195,7 @@ func (w *WorkerICE) Close() {
w.agent = nil w.agent = nil
} }
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) { func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil { if err != nil {
return nil, fmt.Errorf("create agent: %w", err) return nil, fmt.Errorf("create agent: %w", err)
@@ -213,10 +213,6 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
return nil, err return nil, err
} }
if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil return agent, nil
} }
@@ -230,7 +226,7 @@ func (w *WorkerICE) SessionID() ICESessionID {
// will block until connection succeeded // will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state, // but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection // so we have to cancel it with the provided context once agent detected a broken connection
func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("gather candidates") w.log.Debugf("gather candidates")
if err := agent.GatherCandidates(); err != nil { if err := agent.GatherCandidates(); err != nil {
w.log.Warnf("failed to gather candidates: %s", err) w.log.Warnf("failed to gather candidates: %s", err)
@@ -239,7 +235,7 @@ func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAn
} }
w.log.Debugf("turn agent dial") w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(ctx, remoteOfferAnswer) remoteConn, err := w.turnAgentDial(ctx, agent, remoteOfferAnswer)
if err != nil { if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err) w.log.Debugf("failed to dial the remote peer: %s", err)
w.closeAgent(agent, w.agentDialerCancel) w.closeAgent(agent, w.agentDialerCancel)
@@ -252,6 +248,11 @@ func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAn
w.closeAgent(agent, w.agentDialerCancel) w.closeAgent(agent, w.agentDialerCancel)
return return
} }
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.closeAgent(agent, w.agentDialerCancel)
return
}
if !isRelayCandidate(pair.Local) { if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one // dynamically set remote WireGuard port if other side specified a different one from the default one
@@ -290,13 +291,14 @@ func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAn
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
} }
func (w *WorkerICE) closeAgent(agent *ice.Agent, cancel context.CancelFunc) { func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
cancel() cancel()
if err := agent.Close(); err != nil { if err := agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.muxAgent.Lock() w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID() sessionID, err := NewICESessionID()
if err != nil { if err != nil {
w.log.Errorf("failed to create new session ID: %s", err) w.log.Errorf("failed to create new session ID: %s", err)
@@ -377,16 +379,40 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key) w.config.Key)
w.muxAgent.Lock()
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
w.log.Warnf("failed to get selected candidate pair: %s", err)
w.muxAgent.Unlock()
return
}
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.muxAgent.Unlock()
return
}
w.muxAgent.Unlock()
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
} }
func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) { return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state { switch state {
case ice.ConnectionStateConnected: case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected w.lastKnownState = ice.ConnectionStateConnected
return return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected()
@@ -398,13 +424,6 @@ func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel conte
} }
} }
func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) {
if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true return true
@@ -412,18 +431,18 @@ func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool
return false return false
} }
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key isControlling := w.config.LocalKey > w.config.Key
if isControlling { if isControlling {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else { } else {
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} }
} }
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress() relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(), Network: candidate.NetworkType().String(),
Address: candidate.Address(), Address: candidate.Address(),
Port: relatedAdd.Port, Port: relatedAdd.Port,
@@ -431,6 +450,17 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
RelAddr: relatedAdd.Address, RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port, RelPort: relatedAdd.Port,
}) })
if err != nil {
return nil, err
}
for _, e := range candidate.Extensions() {
if err := ec.AddExtension(e); err != nil {
return nil, err
}
}
return ec, nil
} }
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {

View File

@@ -75,6 +75,8 @@ type ConfigInput struct {
DNSLabels domain.List DNSLabels domain.List
LazyConnectionEnabled *bool LazyConnectionEnabled *bool
MTU *uint16
} }
// Config Configuration type // Config Configuration type
@@ -141,6 +143,8 @@ type Config struct {
ClientCertKeyPair *tls.Certificate `json:"-"` ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool LazyConnectionEnabled bool
MTU uint16
} }
var ConfigDirOverride string var ConfigDirOverride string
@@ -493,6 +497,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.MTU != nil && *input.MTU != config.MTU {
log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU)
config.MTU = *input.MTU
updated = true
} else if config.MTU == 0 {
config.MTU = iface.DefaultMTU
log.Infof("using default MTU %d", config.MTU)
updated = true
}
return updated, nil return updated, nil
} }

View File

@@ -7,7 +7,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/pion/stun/v2" "github.com/pion/stun/v3"
"github.com/pion/turn/v3" "github.com/pion/turn/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"

View File

@@ -2,11 +2,13 @@ package dnsinterceptor
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -26,6 +28,8 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const dnsTimeout = 8 * time.Second
type domainMap map[domain.Domain][]netip.Prefix type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface { type internalDNATer interface {
@@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil { if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return return
@@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil { if err != nil {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
@@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
} }
return return
} }
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""
}
peerState, err := d.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
}
return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
}

View File

@@ -36,8 +36,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -368,7 +368,11 @@ func (m *DefaultManager) UpdateRoutes(
var merr *multierror.Error var merr *multierror.Error
if !m.disableClientRoutes { if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(clientRoutes)
// Update route selector based on management server's isSelected status
m.updateRouteSelectorFromManagement(clientRoutes)
filteredClientRoutes := m.routeSelector.FilterSelectedExitNodes(clientRoutes)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil { if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err)) merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
@@ -430,7 +434,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks) networks = m.routeSelector.FilterSelectedExitNodes(networks)
m.notifier.OnNewRoutes(networks) m.notifier.OnNewRoutes(networks)
@@ -583,3 +587,106 @@ func resolveURLsToIPs(urls []string) []net.IP {
} }
return ips return ips
} }
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
return
}
m.updateExitNodeSelections(exitNodeInfo)
m.logExitNodeUpdate(exitNodeInfo)
}
type exitNodeInfo struct {
allIDs []route.NetID
selectedByManagement []route.NetID
userSelected []route.NetID
userDeselected []route.NetID
}
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
var info exitNodeInfo
for haID, routes := range clientRoutes {
if !m.isExitNodeRoute(routes) {
continue
}
netID := haID.NetID()
info.allIDs = append(info.allIDs, netID)
if m.routeSelector.HasUserSelectionForRoute(netID) {
m.categorizeUserSelection(netID, &info)
} else {
m.checkManagementSelection(routes, netID, &info)
}
}
return info
}
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
}
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {
if m.routeSelector.IsSelected(netID) {
info.userSelected = append(info.userSelected, netID)
} else {
info.userDeselected = append(info.userDeselected, netID)
}
}
func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID route.NetID, info *exitNodeInfo) {
for _, route := range routes {
if !route.SkipAutoApply {
info.selectedByManagement = append(info.selectedByManagement, netID)
break
}
}
}
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
m.deselectExitNodes(routesToDeselect)
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
}
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
var routesToDeselect []route.NetID
for _, netID := range allIDs {
if !m.routeSelector.HasUserSelectionForRoute(netID) {
routesToDeselect = append(routesToDeselect, netID)
}
}
return routesToDeselect
}
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
if len(routesToDeselect) == 0 {
return
}
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
if err != nil {
log.Warnf("Failed to deselect exit nodes: %v", err)
}
}
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
if len(selectedByManagement) == 0 {
return
}
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
if err != nil {
log.Warnf("Failed to select exit nodes: %v", err)
}
}
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
}

View File

@@ -190,14 +190,15 @@ func TestManagerUpdateRoutes(t *testing.T) {
name: "No Small Client Route Should Be Added", name: "No Small Client Route Should Be Added",
inputRoutes: []*route.Route{ inputRoutes: []*route.Route{
{ {
ID: "a", ID: "a",
NetID: "routeA", NetID: "routeA",
Peer: remotePeerKey1, Peer: remotePeerKey1,
Network: netip.MustParsePrefix("0.0.0.0/0"), Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network, NetworkType: route.IPv4Network,
Metric: 9999, Metric: 9999,
Masquerade: false, Masquerade: false,
Enabled: true, Enabled: true,
SkipAutoApply: false,
}, },
}, },
inputSerial: 1, inputSerial: 1,

View File

@@ -336,7 +336,7 @@ func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
if e1 != 0 { if e1 != 0 {
return fmt.Errorf("CreateIpForwardEntry2: %w", e1) return fmt.Errorf("CreateIpForwardEntry2: %w", e1)
} }
return fmt.Errorf("CreateIpForwardEntry2: code %d", r1) return fmt.Errorf("CreateIpForwardEntry2: code %d", windows.NTStatus(r1))
} }
return nil return nil
} }

View File

@@ -13,4 +13,6 @@ var (
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
ExitNodeCIDR = "0.0.0.0/0"
) )

View File

@@ -9,19 +9,27 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const (
exitNodeCIDR = "0.0.0.0/0"
)
type RouteSelector struct { type RouteSelector struct {
mu sync.RWMutex mu sync.RWMutex
deselectedRoutes map[route.NetID]struct{} deselectedRoutes map[route.NetID]struct{}
selectedRoutes map[route.NetID]struct{}
deselectAll bool deselectAll bool
} }
func NewRouteSelector() *RouteSelector { func NewRouteSelector() *RouteSelector {
return &RouteSelector{ return &RouteSelector{
deselectedRoutes: map[route.NetID]struct{}{}, deselectedRoutes: map[route.NetID]struct{}{},
selectedRoutes: map[route.NetID]struct{}{},
deselectAll: false, deselectAll: false,
} }
} }
@@ -32,7 +40,14 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
defer rs.mu.Unlock() defer rs.mu.Unlock()
if !appendRoute || rs.deselectAll { if !appendRoute || rs.deselectAll {
if rs.deselectedRoutes == nil {
rs.deselectedRoutes = map[route.NetID]struct{}{}
}
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
maps.Clear(rs.deselectedRoutes) maps.Clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes)
for _, r := range allRoutes { for _, r := range allRoutes {
rs.deselectedRoutes[r] = struct{}{} rs.deselectedRoutes[r] = struct{}{}
} }
@@ -45,6 +60,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
continue continue
} }
delete(rs.deselectedRoutes, route) delete(rs.deselectedRoutes, route)
rs.selectedRoutes[route] = struct{}{}
} }
rs.deselectAll = false rs.deselectAll = false
@@ -58,7 +74,14 @@ func (rs *RouteSelector) SelectAllRoutes() {
defer rs.mu.Unlock() defer rs.mu.Unlock()
rs.deselectAll = false rs.deselectAll = false
if rs.deselectedRoutes == nil {
rs.deselectedRoutes = map[route.NetID]struct{}{}
}
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
maps.Clear(rs.deselectedRoutes) maps.Clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes)
} }
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
@@ -77,6 +100,7 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
continue continue
} }
rs.deselectedRoutes[route] = struct{}{} rs.deselectedRoutes[route] = struct{}{}
delete(rs.selectedRoutes, route)
} }
return errors.FormatErrorOrNil(err) return errors.FormatErrorOrNil(err)
@@ -88,7 +112,14 @@ func (rs *RouteSelector) DeselectAllRoutes() {
defer rs.mu.Unlock() defer rs.mu.Unlock()
rs.deselectAll = true rs.deselectAll = true
if rs.deselectedRoutes == nil {
rs.deselectedRoutes = map[route.NetID]struct{}{}
}
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
maps.Clear(rs.deselectedRoutes) maps.Clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes)
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.
@@ -97,11 +128,14 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
defer rs.mu.RUnlock() defer rs.mu.RUnlock()
if rs.deselectAll { if rs.deselectAll {
log.Debugf("Route %s not selected (deselect all)", routeID)
return false return false
} }
_, deselected := rs.deselectedRoutes[routeID] _, deselected := rs.deselectedRoutes[routeID]
return !deselected isSelected := !deselected
log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected)
return isSelected
} }
// FilterSelected removes unselected routes from the provided map. // FilterSelected removes unselected routes from the provided map.
@@ -124,15 +158,98 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
return filtered return filtered
} }
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.deselectAll {
return route.HAMap{}
}
filtered := make(route.HAMap, len(routes))
for id, rt := range routes {
netID := id.NetID()
if rs.isDeselected(netID) {
continue
}
if !isExitNode(rt) {
filtered[id] = rt
continue
}
rs.applyExitNodeFilter(id, netID, rt, filtered)
}
return filtered
}
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
_, deselected := rs.deselectedRoutes[netID]
return deselected || rs.deselectAll
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelections() {
// user made explicit selects/deselects
if rs.IsSelected(netID) {
out[id] = rt
}
return
}
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func (rs *RouteSelector) hasUserSelections() bool {
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}
// MarshalJSON implements the json.Marshaler interface // MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) { func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock() rs.mu.RLock()
defer rs.mu.RUnlock() defer rs.mu.RUnlock()
return json.Marshal(struct { return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
DeselectAll bool `json:"deselect_all"` DeselectAll bool `json:"deselect_all"`
}{ }{
SelectedRoutes: rs.selectedRoutes,
DeselectedRoutes: rs.deselectedRoutes, DeselectedRoutes: rs.deselectedRoutes,
DeselectAll: rs.deselectAll, DeselectAll: rs.deselectAll,
}) })
@@ -147,11 +264,13 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
// Check for null or empty JSON // Check for null or empty JSON
if len(data) == 0 || string(data) == "null" { if len(data) == 0 || string(data) == "null" {
rs.deselectedRoutes = map[route.NetID]struct{}{} rs.deselectedRoutes = map[route.NetID]struct{}{}
rs.selectedRoutes = map[route.NetID]struct{}{}
rs.deselectAll = false rs.deselectAll = false
return nil return nil
} }
var temp struct { var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
DeselectAll bool `json:"deselect_all"` DeselectAll bool `json:"deselect_all"`
} }
@@ -160,12 +279,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
return err return err
} }
rs.selectedRoutes = temp.SelectedRoutes
rs.deselectedRoutes = temp.DeselectedRoutes rs.deselectedRoutes = temp.DeselectedRoutes
rs.deselectAll = temp.DeselectAll rs.deselectAll = temp.DeselectAll
if rs.deselectedRoutes == nil { if rs.deselectedRoutes == nil {
rs.deselectedRoutes = map[route.NetID]struct{}{} rs.deselectedRoutes = map[route.NetID]struct{}{}
} }
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package routeselector_test package routeselector_test
import ( import (
"net/netip"
"slices" "slices"
"testing" "testing"
@@ -273,6 +274,62 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
}, filtered) }, filtered)
} }
func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Create test routes
exitNode1 := &route.Route{
ID: "route1",
NetID: "net1",
Network: netip.MustParsePrefix("0.0.0.0/0"),
Peer: "peer1",
SkipAutoApply: false,
}
exitNode2 := &route.Route{
ID: "route2",
NetID: "net1",
Network: netip.MustParsePrefix("0.0.0.0/0"),
Peer: "peer2",
SkipAutoApply: true,
}
normalRoute := &route.Route{
ID: "route3",
NetID: "net2",
Network: netip.MustParsePrefix("192.168.1.0/24"),
Peer: "peer3",
SkipAutoApply: false,
}
routes := route.HAMap{
"net1|0.0.0.0/0": {exitNode1, exitNode2},
"net2|192.168.1.0/24": {normalRoute},
}
// Test filtering
filtered := rs.FilterSelectedExitNodes(routes)
// Should only include selected exit nodes and all normal routes
assert.Len(t, filtered, 2)
assert.Len(t, filtered["net1|0.0.0.0/0"], 1) // Only the selected exit node
assert.Equal(t, exitNode1.ID, filtered["net1|0.0.0.0/0"][0].ID)
assert.Len(t, filtered["net2|192.168.1.0/24"], 1) // Normal route should be included
assert.Equal(t, normalRoute.ID, filtered["net2|192.168.1.0/24"][0].ID)
// Test with deselected routes
err := rs.DeselectRoutes([]route.NetID{"net1"}, []route.NetID{"net1", "net2"})
assert.NoError(t, err)
filtered = rs.FilterSelectedExitNodes(routes)
assert.Len(t, filtered, 1) // Only normal route should remain
assert.Len(t, filtered["net2|192.168.1.0/24"], 1)
assert.Equal(t, normalRoute.ID, filtered["net2|192.168.1.0/24"][0].ID)
// Test with deselect all
rs = routeselector.NewRouteSelector()
rs.DeselectAllRoutes()
filtered = rs.FilterSelectedExitNodes(routes)
assert.Len(t, filtered, 0) // No routes should be selected
}
func TestRouteSelector_NewRoutesBehavior(t *testing.T) { func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"} initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}

View File

@@ -9,6 +9,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
) )
@@ -32,9 +33,15 @@ type Net struct {
// NewNetWithDiscover creates a new StdNet instance. // NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
n := &Net{ n := &Net{
iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover),
interfaceFilter: InterfaceFilter(disallowList), interfaceFilter: InterfaceFilter(disallowList),
} }
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
// so in android cli use pionDiscover
if netstack.IsEnabled() {
n.iFaceDiscover = pionDiscover{}
} else {
n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover)
}
return n, n.UpdateInterfaces() return n, n.UpdateInterfaces()
} }

View File

@@ -278,6 +278,7 @@ type LoginRequest struct {
BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"` BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"`
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -530,6 +531,13 @@ func (x *LoginRequest) GetUsername() string {
return "" return ""
} }
func (x *LoginRequest) GetMtu() int64 {
if x != nil && x.Mtu != nil {
return *x.Mtu
}
return 0
}
type LoginResponse struct { type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -1034,6 +1042,7 @@ type GetConfigResponse struct {
AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"`
InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"`
WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"`
Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"`
DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"`
ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"`
RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
@@ -1129,6 +1138,13 @@ func (x *GetConfigResponse) GetWireguardPort() int64 {
return 0 return 0
} }
func (x *GetConfigResponse) GetMtu() int64 {
if x != nil {
return x.Mtu
}
return 0
}
func (x *GetConfigResponse) GetDisableAutoConnect() bool { func (x *GetConfigResponse) GetDisableAutoConnect() bool {
if x != nil { if x != nil {
return x.DisableAutoConnect return x.DisableAutoConnect
@@ -3679,6 +3695,7 @@ type SetConfigRequest struct {
// cleanDNSLabels clean map list of DNS labels. // cleanDNSLabels clean map list of DNS labels.
CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"`
DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"`
Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -3902,6 +3919,13 @@ func (x *SetConfigRequest) GetDnsRouteInterval() *durationpb.Duration {
return nil return nil
} }
func (x *SetConfigRequest) GetMtu() int64 {
if x != nil && x.Mtu != nil {
return *x.Mtu
}
return 0
}
type SetConfigResponse struct { type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
@@ -4575,7 +4599,7 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" + const file_daemon_proto_rawDesc = "" +
"\n" + "\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
"\fEmptyRequest\"\xa4\x0e\n" + "\fEmptyRequest\"\xc3\x0e\n" +
"\fLoginRequest\x12\x1a\n" + "\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4611,7 +4635,8 @@ const file_daemon_proto_rawDesc = "" +
"\x15lazyConnectionEnabled\x18\x1c \x01(\bH\x0fR\x15lazyConnectionEnabled\x88\x01\x01\x12(\n" + "\x15lazyConnectionEnabled\x18\x1c \x01(\bH\x0fR\x15lazyConnectionEnabled\x88\x01\x01\x12(\n" +
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01B\x13\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" + "\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" + "\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" + "\x0e_wireguardPortB\x17\n" +
@@ -4630,7 +4655,8 @@ const file_daemon_proto_rawDesc = "" +
"\x16_lazyConnectionEnabledB\x10\n" + "\x16_lazyConnectionEnabledB\x10\n" +
"\x0e_block_inboundB\x0e\n" + "\x0e_block_inboundB\x0e\n" +
"\f_profileNameB\v\n" + "\f_profileNameB\v\n" +
"\t_username\"\xb5\x01\n" + "\t_usernameB\x06\n" +
"\x04_mtu\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" + "\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
@@ -4661,7 +4687,7 @@ const file_daemon_proto_rawDesc = "" +
"\fDownResponse\"P\n" + "\fDownResponse\"P\n" +
"\x10GetConfigRequest\x12 \n" + "\x10GetConfigRequest\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
"\busername\x18\x02 \x01(\tR\busername\"\xa3\x06\n" + "\busername\x18\x02 \x01(\tR\busername\"\xb5\x06\n" +
"\x11GetConfigResponse\x12$\n" + "\x11GetConfigResponse\x12$\n" +
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
"\n" + "\n" +
@@ -4671,7 +4697,8 @@ const file_daemon_proto_rawDesc = "" +
"\fpreSharedKey\x18\x04 \x01(\tR\fpreSharedKey\x12\x1a\n" + "\fpreSharedKey\x18\x04 \x01(\tR\fpreSharedKey\x12\x1a\n" +
"\badminURL\x18\x05 \x01(\tR\badminURL\x12$\n" + "\badminURL\x18\x05 \x01(\tR\badminURL\x12$\n" +
"\rinterfaceName\x18\x06 \x01(\tR\rinterfaceName\x12$\n" + "\rinterfaceName\x18\x06 \x01(\tR\rinterfaceName\x12$\n" +
"\rwireguardPort\x18\a \x01(\x03R\rwireguardPort\x12.\n" + "\rwireguardPort\x18\a \x01(\x03R\rwireguardPort\x12\x10\n" +
"\x03mtu\x18\b \x01(\x03R\x03mtu\x12.\n" +
"\x12disableAutoConnect\x18\t \x01(\bR\x12disableAutoConnect\x12*\n" + "\x12disableAutoConnect\x18\t \x01(\bR\x12disableAutoConnect\x12*\n" +
"\x10serverSSHAllowed\x18\n" + "\x10serverSSHAllowed\x18\n" +
" \x01(\bR\x10serverSSHAllowed\x12*\n" + " \x01(\bR\x10serverSSHAllowed\x12*\n" +
@@ -4885,7 +4912,7 @@ const file_daemon_proto_rawDesc = "" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" + "\f_profileNameB\v\n" +
"\t_username\"\x17\n" + "\t_username\"\x17\n" +
"\x15SwitchProfileResponse\"\xef\f\n" + "\x15SwitchProfileResponse\"\x8e\r\n" +
"\x10SetConfigRequest\x12\x1a\n" + "\x10SetConfigRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
@@ -4917,7 +4944,8 @@ const file_daemon_proto_rawDesc = "" +
"\n" + "\n" +
"dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" + "dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" +
"\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" + "\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" +
"\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01B\x13\n" + "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" +
"\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" + "\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" + "\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" + "\x0e_wireguardPortB\x17\n" +
@@ -4934,7 +4962,8 @@ const file_daemon_proto_rawDesc = "" +
"\x16_disable_notificationsB\x18\n" + "\x16_disable_notificationsB\x18\n" +
"\x16_lazyConnectionEnabledB\x10\n" + "\x16_lazyConnectionEnabledB\x10\n" +
"\x0e_block_inboundB\x13\n" + "\x0e_block_inboundB\x13\n" +
"\x11_dnsRouteInterval\"\x13\n" + "\x11_dnsRouteIntervalB\x06\n" +
"\x04_mtu\"\x13\n" +
"\x11SetConfigResponse\"Q\n" + "\x11SetConfigResponse\"Q\n" +
"\x11AddProfileRequest\x12\x1a\n" + "\x11AddProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" +

View File

@@ -156,6 +156,8 @@ message LoginRequest {
optional string profileName = 30; optional string profileName = 30;
optional string username = 31; optional string username = 31;
optional int64 mtu = 32;
} }
message LoginResponse { message LoginResponse {
@@ -223,6 +225,8 @@ message GetConfigResponse {
int64 wireguardPort = 7; int64 wireguardPort = 7;
int64 mtu = 8;
bool disableAutoConnect = 9; bool disableAutoConnect = 9;
bool serverSSHAllowed = 10; bool serverSSHAllowed = 10;
@@ -583,6 +587,7 @@ message SetConfigRequest {
optional google.protobuf.Duration dnsRouteInterval = 27; optional google.protobuf.Duration dnsRouteInterval = 27;
optional int64 mtu = 28;
} }
message SetConfigResponse{} message SetConfigResponse{}

View File

@@ -400,6 +400,11 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound config.BlockInbound = msg.BlockInbound
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
config.MTU = &mtu
}
if _, err := profilemanager.UpdateConfig(config); err != nil { if _, err := profilemanager.UpdateConfig(config); err != nil {
log.Errorf("failed to update profile config: %v", err) log.Errorf("failed to update profile config: %v", err)
return nil, fmt.Errorf("failed to update profile config: %w", err) return nil, fmt.Errorf("failed to update profile config: %w", err)
@@ -484,6 +489,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
// nolint // nolint
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, msg.Hostname) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, msg.Hostname)
} }
s.mutex.Unlock() s.mutex.Unlock()
config, err := s.getConfig(activeProf) config, err := s.getConfig(activeProf)
@@ -1105,6 +1111,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
AdminURL: adminURL.String(), AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface, InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort), WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect, DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed, ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled, RosenpassEnabled: cfg.RosenpassEnabled,

View File

@@ -10,25 +10,24 @@ import (
"time" "time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto" daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -294,15 +293,20 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"strings" "strings"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
@@ -95,14 +96,6 @@ func (i *Info) SetFlags(
i.LazyConnectionEnabled = lazyConnectionEnabled i.LazyConnectionEnabled = lazyConnectionEnabled
} }
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string { func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx) md, hasMeta := metadata.FromOutgoingContext(ctx)
@@ -180,6 +173,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
// GetInfoWithChecks retrieves and parses the system information with applied checks. // GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))
processCheckPaths := make([]string, 0) processCheckPaths := make([]string, 0)
for _, check := range checks { for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...) processCheckPaths = append(processCheckPaths, check.GetFiles()...)
@@ -189,16 +183,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("gathering process check information completed")
info := GetInfo(ctx) info := GetInfo(ctx)
info.Files = files info.Files = files
log.Debugf("all system information gathered successfully")
return info, nil return info, nil
} }
// UpdateStaticInfo asynchronously updates static system and platform information
func UpdateStaticInfo() {
go func() {
_ = updateStaticInfo()
}()
}

View File

@@ -15,6 +15,11 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
kernel := "android" kernel := "android"

View File

@@ -19,6 +19,10 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
func UpdateStaticInfoAsync() {
go updateStaticInfo()
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
utsname := unix.Utsname{} utsname := unix.Utsname{}
@@ -41,7 +45,7 @@ func GetInfo(ctx context.Context) *Info {
} }
start := time.Now() start := time.Now()
si := updateStaticInfo() si := getStaticInfo()
if time.Since(start) > 1*time.Second { if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start)) log.Warnf("updateStaticInfo took %s", time.Since(start))
} }

View File

@@ -18,6 +18,11 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
out := _getInfo() out := _getInfo()

View File

@@ -10,6 +10,11 @@ import (
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {

View File

@@ -23,6 +23,10 @@ var (
getSystemInfo = defaultSysInfoImplementation getSystemInfo = defaultSysInfoImplementation
) )
func UpdateStaticInfoAsync() {
go updateStaticInfo()
}
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
info := _getInfo() info := _getInfo()
@@ -48,7 +52,7 @@ func GetInfo(ctx context.Context) *Info {
} }
start := time.Now() start := time.Now()
si := updateStaticInfo() si := getStaticInfo()
if time.Since(start) > 1*time.Second { if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start)) log.Warnf("updateStaticInfo took %s", time.Since(start))
} }

View File

@@ -2,187 +2,51 @@ package system
import ( import (
"context" "context"
"fmt"
"os" "os"
"runtime" "runtime"
"strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
type Win32_OperatingSystem struct { func UpdateStaticInfoAsync() {
Caption string go updateStaticInfo()
}
type Win32_ComputerSystem struct {
Manufacturer string
}
type Win32_ComputerSystemProduct struct {
Name string
}
type Win32_BIOS struct {
SerialNumber string
} }
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
osName, osVersion := getOSNameAndVersion()
buildVersion := getBuildVersion()
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
start := time.Now() start := time.Now()
si := updateStaticInfo() si := getStaticInfo()
if time.Since(start) > 1*time.Second { if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start)) log.Warnf("updateStaticInfo took %s", time.Since(start))
} }
gio := &Info{ gio := &Info{
Kernel: "windows", Kernel: "windows",
OSVersion: osVersion, OSVersion: si.OSVersion,
Platform: "unknown", Platform: "unknown",
OS: osName, OS: si.OSName,
GoOS: runtime.GOOS, GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(), CPUs: runtime.NumCPU(),
KernelVersion: buildVersion, KernelVersion: si.BuildVersion,
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber, SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName, SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer, SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment, Environment: si.Environment,
} }
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
} else {
gio.NetworkAddresses = addrs
}
systemHostname, _ := os.Hostname() systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname) gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.NetbirdVersion = version.NetbirdVersion() gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx) gio.UIVersion = extractUserAgent(ctx)
return gio return gio
} }
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
log.Error(err)
return "Windows", getBuildVersion()
}
if len(dst) == 0 {
return "Windows", getBuildVersion()
}
split := strings.Split(dst[0].Caption, " ")
if len(split) <= 3 {
return "Windows", getBuildVersion()
}
name := split[1]
version := split[2]
if split[2] == "Server" {
name = fmt.Sprintf("%s %s", split[1], split[2])
version = split[3]
}
return name, version
}
func getBuildVersion() string {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
log.Error(err)
return "0.0.0.0"
}
defer func() {
deferErr := k.Close()
if deferErr != nil {
log.Error(deferErr)
}
}()
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
if err != nil {
log.Error(err)
}
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
if err != nil {
log.Error(err)
}
build, _, err := k.GetStringValue("CurrentBuildNumber")
if err != nil {
log.Error(err)
}
// Update Build Revision
ubr, _, err := k.GetIntegerValue("UBR")
if err != nil {
log.Error(err)
}
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
return ver
}
func sysNumber() (string, error) {
var dst []Win32_BIOS
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].SerialNumber, nil
}
func sysProductName() (string, error) {
var dst []Win32_ComputerSystemProduct
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
// `ComputerSystemProduct` could be empty on some virtualized systems
if len(dst) < 1 {
return "unknown", nil
}
return dst[0].Name, nil
}
func sysManufacturer() (string, error) {
var dst []Win32_ComputerSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].Manufacturer, nil
}

View File

@@ -3,12 +3,7 @@
package system package system
import ( import (
"context"
"sync" "sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
) )
var ( var (
@@ -16,25 +11,26 @@ var (
once sync.Once once sync.Once
) )
func updateStaticInfo() StaticInfo { // StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
// Windows specific fields
OSName string
OSVersion string
BuildVersion string
}
func updateStaticInfo() {
once.Do(func() { once.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) staticInfo = newStaticInfo()
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
}) })
}
func getStaticInfo() StaticInfo {
updateStaticInfo()
return staticInfo return staticInfo
} }

Some files were not shown because too many files have changed in this diff Show More