Compare commits

...

37 Commits

Author SHA1 Message Date
dependabot[bot]
20bb4095f2 Bump github.com/docker/docker
Bumps [github.com/docker/docker](https://github.com/docker/docker) from 26.1.5+incompatible to 28.0.0+incompatible.
- [Release notes](https://github.com/docker/docker/releases)
- [Commits](https://github.com/docker/docker/compare/v26.1.5...v28.0.0)

---
updated-dependencies:
- dependency-name: github.com/docker/docker
  dependency-version: 28.0.0+incompatible
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-27 16:32:02 +00:00
Vlad
99bd34c02a [signal] fix goroutines and memory leak on forward messages between peers (#3896) 2025-08-27 19:30:49 +03:00
Krzysztof Nazarewski (kdn)
7ce5507c05 [client] fix darwin dns always throwing err (#4403)
* fix: dns/host_darwin.go was missing if err != nil before throwing error
2025-08-27 09:59:39 +02:00
Pascal Fischer
0320bb7b35 [management] Report sync duration and login duration by accountID (#4406) 2025-08-26 22:32:12 +02:00
Viktor Liu
f063866ce8 [client] Add flag to configure MTU (#4213) 2025-08-26 16:00:14 +02:00
plusls
9f84165763 [client] Add netstack support for Android cli (#4319) 2025-08-26 15:40:01 +02:00
Pascal Fischer
3488a516c9 [management] Move increment network serial as last step of each transaction (#4397) 2025-08-25 17:27:07 +02:00
Pascal Fischer
5e273c121a [management] Remove store locks 3 (#4390) 2025-08-21 20:47:28 +02:00
Bethuel Mmbaga
968d95698e [management] Bump github.com/golang-jwt/jwt from 3.2.2+incompatible to 5.3.0 (#4375) 2025-08-21 15:02:51 +03:00
Pascal Fischer
28bef26537 [management] Remove Store Locks 2 (#4385) 2025-08-21 12:23:49 +02:00
Pascal Fischer
0d2845ea31 [management] optimize proxy network map (#4324) 2025-08-20 19:04:19 +02:00
Zoltan Papp
f425870c8e [client] Avoid duplicated agent close (#4383) 2025-08-20 18:50:51 +02:00
Pascal Fischer
f9d64a06c2 [management] Remove all store locks from grpc side (#4374) 2025-08-20 12:41:14 +02:00
hakansa
86555c44f7 refactor doc workflow (#4373)
refactor doc workflow (#4373)
2025-08-20 10:59:32 +03:00
Bastien Jeannelle
48792c64cd [misc] Fix confusing comment (#4376) 2025-08-20 00:12:00 +02:00
hakansa
533d93eb17 [management,client] Feat/exit node auto apply (#4272)
[management,client] Feat/exit node auto apply (#4272)
2025-08-19 18:19:24 +03:00
dependabot[bot]
9685411246 [misc] Bump golang.org/x/oauth2 from 0.24.0 to 0.27.0 (#4176)
Bumps [golang.org/x/oauth2](https://github.com/golang/oauth2) from 0.24.0 to 0.27.0
2025-08-19 16:26:46 +03:00
hakansa
d00a226556 [management] Add CreatedAt field to Peer and PeerBatch models (#4371)
[management] Add CreatedAt field to Peer and PeerBatch models (#4371)
2025-08-19 16:02:11 +03:00
Pascal Fischer
5d361b5421 [management] add nil handling for route domains (#4366) 2025-08-19 11:35:03 +02:00
dependabot[bot]
a889c4108b [misc] Bump github.com/containerd/containerd from 1.7.16 to 1.7.27 (#3527)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.16 to 1.7.27
2025-08-18 21:57:21 +03:00
Zoltan Papp
12cad854b2 [client] Fix/ice handshake (#4281)
In this PR, speed up the GRPC message processing, force the recreation of the ICE agent when getting a new, remote offer (do not wait for local STUN timeout).
2025-08-18 20:09:50 +02:00
Pascal Fischer
6a3846a8b7 [management] Remove save account calls (#4349) 2025-08-18 12:37:20 +02:00
Viktor Liu
7cd5dcae59 [client] Fix rule order for deny rules in peer ACLs (#4147) 2025-08-18 11:17:00 +02:00
Pascal Fischer
0e62325d46 [management] fail on geo location init failure (#4362) 2025-08-18 10:53:55 +02:00
Pascal Fischer
b3056d0937 [management] Use DI containers for server bootstrapping (#4343) 2025-08-15 17:14:48 +02:00
Zoltan Papp
ab853ac2a5 [server] Add MySQL initialization script and update Docker configuration (#4345) 2025-08-14 17:53:59 +02:00
Misha Bragin
e97f853909 Improve wording in the NetBird client app (#4316) 2025-08-13 22:03:48 +02:00
hakansa
70db8751d7 [client] Add --disable-update-settings flag to the service (#4335)
[client] Add --disable-update-settings flag to the service (#4335)
2025-08-13 21:05:12 +03:00
Zoltan Papp
86a00ab4af Fix Go tarball version in FreeBSD build configuration (#4339) 2025-08-13 13:52:11 +02:00
Zoltan Papp
3d4b502126 [server] Add health check HTTP endpoint for Relay server (#4297)
The health check endpoint listens on a dedicated HTTP server.
By default, it is available at 0.0.0.0:9000/health. This can be configured using the --health-listen-address flag.

The results are cached for 3 seconds to avoid excessive calls.

The health check performs the following:

Checks the number of active listeners.
Validates each listener via WebSocket and QUIC dials, including TLS certificate verification.
2025-08-13 10:40:04 +02:00
Bethuel Mmbaga
a4e8647aef [management] Enable flow groups (#4230)
Adds the ability to limit traffic events logging to specific peer groups
2025-08-13 00:00:40 +03:00
Viktor Liu
160b811e21 [client] Distinguish between NXDOMAIN and NODATA in the dns forwarder (#4321) 2025-08-12 15:59:42 +02:00
Viktor Liu
5e607cf4e9 [client] Skip dns upstream servers pointing to our dns server IP to prevent loops (#4330) 2025-08-12 15:41:23 +02:00
Viktor Liu
0fdb944058 [client] Create NRPT rules separately per domain (#4329) 2025-08-12 15:40:37 +02:00
Zoltan Papp
ccbabd9e2a Add pprof support for Relay server (#4203) 2025-08-12 12:24:24 +02:00
Pascal Fischer
a942e4add5 [management] use readlock on add peer (#4308) 2025-08-11 15:21:26 +02:00
Viktor Liu
1022a5015c [client] Eliminate upstream server strings in dns code (#4267) 2025-08-11 11:57:21 +02:00
223 changed files with 5191 additions and 2538 deletions

View File

@@ -16,19 +16,29 @@ jobs:
steps: steps:
- name: Read PR body - name: Read PR body
id: body id: body
shell: bash
run: | run: |
BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH") set -euo pipefail
echo "body<<EOF" >> $GITHUB_OUTPUT BODY_B64=$(jq -r '.pull_request.body // "" | @base64' "$GITHUB_EVENT_PATH")
echo "$BODY" >> $GITHUB_OUTPUT {
echo "EOF" >> $GITHUB_OUTPUT echo "body_b64=$BODY_B64"
} >> "$GITHUB_OUTPUT"
- name: Validate checkbox selection - name: Validate checkbox selection
id: validate id: validate
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: | run: |
body='${{ steps.body.outputs.body }}' set -euo pipefail
if ! body="$(printf '%s' "$BODY_B64" | base64 -d)"; then
echo "::error::Failed to decode PR body from base64. Data may be corrupted or missing."
exit 1
fi
added_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*I added/updated documentation' | wc -l | tr -d '[:space:]' || true)
noneed_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*Documentation is \*\*not needed\*\*' | wc -l | tr -d '[:space:]' || true)
added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ')
noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ')
if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then
echo "::error::Choose exactly one: either 'docs added' OR 'not needed'." echo "::error::Choose exactly one: either 'docs added' OR 'not needed'."
@@ -41,30 +51,35 @@ jobs:
fi fi
if [ "$added_checked" -eq 1 ]; then if [ "$added_checked" -eq 1 ]; then
echo "mode=added" >> $GITHUB_OUTPUT echo "mode=added" >> "$GITHUB_OUTPUT"
else else
echo "mode=noneed" >> $GITHUB_OUTPUT echo "mode=noneed" >> "$GITHUB_OUTPUT"
fi fi
- name: Extract docs PR URL (when 'docs added') - name: Extract docs PR URL (when 'docs added')
if: steps.validate.outputs.mode == 'added' if: steps.validate.outputs.mode == 'added'
id: extract id: extract
shell: bash
env:
BODY_B64: ${{ steps.body.outputs.body_b64 }}
run: | run: |
body='${{ steps.body.outputs.body }}' set -euo pipefail
body="$(printf '%s' "$BODY_B64" | base64 -d)"
# Strictly require HTTPS and that it's a PR in netbirdio/docs # Strictly require HTTPS and that it's a PR in netbirdio/docs
# Examples accepted: # e.g., https://github.com/netbirdio/docs/pull/1234
# https://github.com/netbirdio/docs/pull/1234 url="$(printf '%s' "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)"
url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)
if [ -z "$url" ]; then if [ -z "${url:-}" ]; then
echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)." echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)."
exit 1 exit 1
fi fi
pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#') pr_number="$(printf '%s' "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')"
echo "url=$url" >> $GITHUB_OUTPUT {
echo "pr_number=$pr_number" >> $GITHUB_OUTPUT echo "url=$url"
echo "pr_number=$pr_number"
} >> "$GITHUB_OUTPUT"
- name: Verify docs PR exists (and is open or merged) - name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added' if: steps.validate.outputs.mode == 'added'

View File

@@ -25,8 +25,7 @@ jobs:
release: "14.2" release: "14.2"
prepare: | prepare: |
pkg install -y curl pkgconf xorg pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1) GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL" GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL" curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL" tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -83,6 +83,15 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Setup MySQL privileges
if: matrix.store == 'mysql'
run: |
sleep 10
mysql -h 127.0.0.1 -u root -pmysqlroot -e "
GRANT SYSTEM_VARIABLES_ADMIN ON *.* TO 'netbird'@'%';
FLUSH PRIVILEGES;
"
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/

View File

@@ -4,6 +4,7 @@ package android
import ( import (
"context" "context"
"slices"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -112,7 +113,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -138,7 +139,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -235,7 +236,7 @@ func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
return err return err
} }
dnsServer.OnUpdatedHostDNSServer(list.items) dnsServer.OnUpdatedHostDNSServer(slices.Clone(list.items))
return nil return nil
} }

View File

@@ -1,23 +1,34 @@
package android package android
import "fmt" import (
"fmt"
"net/netip"
// DNSList is a wrapper of []string "github.com/netbirdio/netbird/client/internal/dns"
)
// DNSList is a wrapper of []netip.AddrPort with default DNS port
type DNSList struct { type DNSList struct {
items []string items []netip.AddrPort
} }
// Add new DNS address to the collection // Add new DNS address to the collection, returns error if invalid
func (array *DNSList) Add(s string) { func (array *DNSList) Add(s string) error {
array.items = append(array.items, s) addr, err := netip.ParseAddr(s)
if err != nil {
return fmt.Errorf("invalid DNS address: %s", s)
}
addrPort := netip.AddrPortFrom(addr.Unmap(), dns.DefaultPort)
array.items = append(array.items, addrPort)
return nil
} }
// Get return an element of the collection // Get return an element of the collection as string
func (array *DNSList) Get(i int) (string, error) { func (array *DNSList) Get(i int) (string, error) {
if i >= len(array.items) || i < 0 { if i >= len(array.items) || i < 0 {
return "", fmt.Errorf("out of range") return "", fmt.Errorf("out of range")
} }
return array.items[i], nil return array.items[i].Addr().String(), nil
} }
// Size return with the size of the collection // Size return with the size of the collection

View File

@@ -3,20 +3,30 @@ package android
import "testing" import "testing"
func TestDNSList_Get(t *testing.T) { func TestDNSList_Get(t *testing.T) {
l := DNSList{ l := DNSList{}
items: make([]string, 1),
// Add a valid DNS address
err := l.Add("8.8.8.8")
if err != nil {
t.Errorf("unexpected error: %s", err)
} }
_, err := l.Get(0) // Test getting valid index
addr, err := l.Get(0)
if err != nil { if err != nil {
t.Errorf("invalid error: %s", err) t.Errorf("invalid error: %s", err)
} }
if addr != "8.8.8.8" {
t.Errorf("expected 8.8.8.8, got %s", addr)
}
// Test negative index
_, err = l.Get(-1) _, err = l.Get(-1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")
} }
// Test out of bounds index
_, err = l.Get(1) _, err = l.Get(1)
if err == nil { if err == nil {
t.Errorf("expected error but got nil") t.Errorf("expected error but got nil")

View File

@@ -33,7 +33,7 @@ var (
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the NetBird daemon.", Long: "Commands for debugging and logging within the NetBird daemon.",
} }
var debugBundleCmd = &cobra.Command{ var debugBundleCmd = &cobra.Command{

View File

@@ -14,7 +14,8 @@ import (
var downCmd = &cobra.Command{ var downCmd = &cobra.Command{
Use: "down", Use: "down",
Short: "down netbird connections", Short: "Disconnect from the NetBird network",
Long: "Disconnect the NetBird client from the network and management service. This will terminate all active connections with the remote peers.",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)

View File

@@ -31,7 +31,8 @@ func init() {
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the NetBird Management Service (first run)", Short: "Log in to the NetBird network",
Long: "Log in to the NetBird network using a setup key or SSO",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setEnvAndFlags(cmd); err != nil { if err := setEnvAndFlags(cmd); err != nil {
return fmt.Errorf("set env and flags: %v", err) return fmt.Errorf("set env and flags: %v", err)

View File

@@ -14,7 +14,8 @@ import (
var logoutCmd = &cobra.Command{ var logoutCmd = &cobra.Command{
Use: "deregister", Use: "deregister",
Aliases: []string{"logout"}, Aliases: []string{"logout"},
Short: "deregister from the NetBird Management Service and delete peer", Short: "Deregister from the NetBird management service and delete this peer",
Long: "This command will deregister the current peer from the NetBird management service and all associated configuration. Use with caution as this will remove the peer from the network.",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)

View File

@@ -15,7 +15,7 @@ var appendFlag bool
var networksCMD = &cobra.Command{ var networksCMD = &cobra.Command{
Use: "networks", Use: "networks",
Aliases: []string{"routes"}, Aliases: []string{"routes"},
Short: "Manage networks", Short: "Manage connections to NetBird Networks and Resources",
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`, Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
} }

View File

@@ -16,13 +16,13 @@ import (
var profileCmd = &cobra.Command{ var profileCmd = &cobra.Command{
Use: "profile", Use: "profile",
Short: "manage NetBird profiles", Short: "Manage NetBird client profiles",
Long: `Manage NetBird profiles, allowing you to list, switch, and remove profiles.`, Long: `Commands to list, add, remove, and switch profiles. Profiles allow you to maintain different accounts in one client app.`,
} }
var profileListCmd = &cobra.Command{ var profileListCmd = &cobra.Command{
Use: "list", Use: "list",
Short: "list all profiles", Short: "List all profiles",
Long: `List all available profiles in the NetBird client.`, Long: `List all available profiles in the NetBird client.`,
Aliases: []string{"ls"}, Aliases: []string{"ls"},
RunE: listProfilesFunc, RunE: listProfilesFunc,
@@ -30,7 +30,7 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{ var profileAddCmd = &cobra.Command{
Use: "add <profile_name>", Use: "add <profile_name>",
Short: "add a new profile", Short: "Add a new profile",
Long: `Add a new profile to the NetBird client. The profile name must be unique.`, Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: addProfileFunc, RunE: addProfileFunc,
@@ -38,16 +38,16 @@ var profileAddCmd = &cobra.Command{
var profileRemoveCmd = &cobra.Command{ var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>", Use: "remove <profile_name>",
Short: "remove a profile", Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be active.`, Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: removeProfileFunc, RunE: removeProfileFunc,
} }
var profileSelectCmd = &cobra.Command{ var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>", Use: "select <profile_name>",
Short: "select a profile", Short: "Select a profile",
Long: `Select a profile to be the active profile in the NetBird client. The profile must exist.`, Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: selectProfileFunc, RunE: selectProfileFunc,
} }

View File

@@ -39,6 +39,7 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection" enableLazyConnectionFlag = "enable-lazy-connection"
mtuFlag = "mtu"
) )
var ( var (
@@ -72,7 +73,9 @@ var (
anonymizeFlag bool anonymizeFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
lazyConnEnabled bool lazyConnEnabled bool
mtu uint16
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",

View File

@@ -54,6 +54,7 @@ func TestSetFlagsFromEnvVars(t *testing.T) {
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.") cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
cmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec") t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name") t.Setenv("NB_INTERFACE_NAME", "test-name")

View File

@@ -19,7 +19,7 @@ import (
var serviceCmd = &cobra.Command{ var serviceCmd = &cobra.Command{
Use: "service", Use: "service",
Short: "manages NetBird service", Short: "Manage the NetBird daemon service",
} }
var ( var (
@@ -42,7 +42,8 @@ func init() {
} }
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile.") serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` + serviceEnvDesc := `Sets extra environment variables for the service. ` +

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
} }
} }
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled) serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
if err := serverInstance.Start(); err != nil { if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err) log.Fatalf("failed to start daemon: %v", err)
} }

View File

@@ -49,6 +49,14 @@ func buildServiceArguments() []string {
args = append(args, "--log-file", logFile) args = append(args, "--log-file", logFile)
} }
if profilesDisabled {
args = append(args, "--disable-profiles")
}
if updateSettingsDisabled {
args = append(args, "--disable-update-settings")
}
return args return args
} }
@@ -99,7 +107,7 @@ func createServiceConfigForInstall() (*service.Config, error) {
var installCmd = &cobra.Command{ var installCmd = &cobra.Command{
Use: "install", Use: "install",
Short: "installs NetBird service", Short: "Install NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
return err return err

View File

@@ -40,7 +40,7 @@ var sshCmd = &cobra.Command{
return nil return nil
}, },
Short: "connect to a remote SSH server", Short: "Connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd) SetFlagsFromEnvVars(cmd)

View File

@@ -32,7 +32,8 @@ var (
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
Use: "status", Use: "status",
Short: "status of the Netbird Service", Short: "Display NetBird client status",
Long: "Display the current status of the NetBird client, including connection status, peer information, and network details.",
RunE: statusFunc, RunE: statusFunc,
} }

View File

@@ -10,7 +10,9 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -26,15 +28,15 @@ import (
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server" mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server" sig "github.com/netbirdio/netbird/signal/server"
) )
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
t.Helper() t.Helper()
config := &types.Config{} config := &config.Config{}
_, err := util.ReadJson("../testdata/management.json", config) _, err := util.ReadJson("../testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -69,7 +71,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *config.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
@@ -97,6 +99,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT(). settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
@@ -108,7 +111,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -134,7 +137,7 @@ func startClientDaemon(
s := grpc.NewServer() s := grpc.NewServer()
server := client.New(ctx, server := client.New(ctx,
"", "", false) "", "", false, false)
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -53,7 +53,8 @@ var (
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
Short: "install, login and start NetBird client", Short: "Connect to the NetBird network",
Long: "Connect to the NetBird network using the provided setup key or SSO auth. This command will bring up the WireGuard interface, connect to the management server, and establish peer-to-peer connections with other peers in the network if required.",
RunE: upFunc, RunE: upFunc,
} }
) )
@@ -62,6 +63,7 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port")
upCmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
@@ -356,6 +358,11 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.WireguardPort = &p req.WireguardPort = &p
} }
if cmd.Flag(mtuFlag).Changed {
m := int64(mtu)
req.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed { if cmd.Flag(networkMonitorFlag).Changed {
req.NetworkMonitor = &networkMonitor req.NetworkMonitor = &networkMonitor
} }
@@ -435,6 +442,13 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.WireguardPort = &p ic.WireguardPort = &p
} }
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
ic.MTU = &mtu
}
if cmd.Flag(networkMonitorFlag).Changed { if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor ic.NetworkMonitor = &networkMonitor
} }
@@ -532,6 +546,14 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.WireguardPort = &wp loginRequest.WireguardPort = &wp
} }
if cmd.Flag(mtuFlag).Changed {
if err := iface.ValidateMTU(mtu); err != nil {
return nil, err
}
m := int64(mtu)
loginRequest.Mtu = &m
}
if cmd.Flag(networkMonitorFlag).Changed { if cmd.Flag(networkMonitorFlag).Changed {
loginRequest.NetworkMonitor = &networkMonitor loginRequest.NetworkMonitor = &networkMonitor
} }

View File

@@ -9,7 +9,7 @@ import (
var ( var (
versionCmd = &cobra.Command{ versionCmd = &cobra.Command{
Use: "version", Use: "version",
Short: "prints NetBird version", Short: "Print the NetBird's client application version",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
cmd.Println(version.NetbirdVersion()) cmd.Println(version.NetbirdVersion())

View File

@@ -85,7 +85,7 @@ func (m *aclManager) AddPeerFiltering(
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
chain := chainNameInputRules chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort) ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs) mangleSpecs := slices.Clone(specs)
@@ -135,7 +135,14 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil { // Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// Insert at the beginning of the chain (position 1)
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
} else {
err = m.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
return nil, err return nil, err
} }
@@ -388,17 +395,25 @@ func actionToStr(action firewall.Action) string {
return "DROP" return "DROP"
} }
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string { func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
switch { if ipsetName == "" {
case ipsetName == "":
return "" return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
}
switch {
case sPort != nil && dPort != nil: case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport" return ipsetName + "-sport-dport" + actionSuffix
case sPort != nil: case sPort != nil:
return ipsetName + "-sport" return ipsetName + "-sport" + actionSuffix
case dPort != nil: case dPort != nil:
return ipsetName + "-dport" return ipsetName + "-dport" + actionSuffix
default: default:
return ipsetName return ipsetName + actionSuffix
} }
} }

View File

@@ -3,6 +3,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
@@ -15,7 +16,7 @@ import (
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
@@ -109,10 +110,84 @@ func TestIptablesManager(t *testing.T) {
}) })
} }
func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
defer func() {
err := manager.Close(nil)
require.NoError(t, err)
}()
t.Run("add deny rule", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{22}}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
require.NoError(t, err, "failed to add deny rule")
require.NotEmpty(t, rule, "deny rule should not be empty")
// Verify the rule was added by checking iptables
for _, r := range rule {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("deny rule precedence test", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.4")
port := &fw.Port{Values: []uint16{80}}
// Add accept rule first
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for same IP/port - this should take precedence
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
// Inspect the actual iptables rules to verify deny rule comes before accept rule
rules, err := ipv4Client.List("filter", chainNameInputRules)
require.NoError(t, err, "failed to list iptables rules")
// Debug: print all rules
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
for i, rule := range rules {
t.Logf(" [%d] %s", i, rule)
}
var denyRuleIndex, acceptRuleIndex int = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
denyRuleIndex = i
}
}
if strings.Contains(rule, "ACCEPT") {
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
acceptRuleIndex = i
}
}
}
require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in iptables")
require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in iptables")
require.Less(t, denyRuleIndex, acceptRuleIndex,
"deny rule should come before accept rule in iptables chain (deny at index %d, accept at index %d)",
denyRuleIndex, acceptRuleIndex)
})
}
func TestIptablesManagerIPSet(t *testing.T) { func TestIptablesManagerIPSet(t *testing.T) {
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
@@ -176,7 +251,7 @@ func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName strin
func TestIptablesCreatePerformance(t *testing.T) { func TestIptablesCreatePerformance(t *testing.T) {
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{

View File

@@ -341,30 +341,38 @@ func (m *AclManager) addIOFiltering(
userData := []byte(ruleId) userData := []byte(ruleId)
chain := m.chainInputRules chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{ rule := &nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Exprs: mainExpressions, Exprs: mainExpressions,
UserData: userData, UserData: userData,
}) }
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var nftRule *nftables.Rule
if action == firewall.ActionDrop {
nftRule = m.rConn.InsertRule(rule)
} else {
nftRule = m.rConn.AddRule(rule)
}
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err) return nil, fmt.Errorf(flushError, err)
} }
rule := &Rule{ ruleStruct := &Rule{
nftRule: nftRule, nftRule: nftRule,
mangleRule: m.createPreroutingRule(expressions, userData), mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset, nftSet: ipset,
ruleID: ruleId, ruleID: ruleId,
ip: ip, ip: ip,
} }
m.rules[ruleId] = rule m.rules[ruleId] = ruleStruct
if ipset != nil { if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name) m.ipsetStore.AddReferenceToIpset(ipset.Name)
} }
return rule, nil return ruleStruct, nil
} }
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule { func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {

View File

@@ -2,6 +2,7 @@ package nftables
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"net/netip" "net/netip"
"os/exec" "os/exec"
@@ -20,7 +21,7 @@ import (
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
@@ -103,9 +104,8 @@ func TestNftablesManager(t *testing.T) {
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) // Since DROP rules are inserted at position 0, the DROP rule comes first
expectedDropExprs := []expr.Any{
expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -141,7 +141,12 @@ func TestNftablesManager(t *testing.T) {
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
// Compare DROP rule at position 0 (inserted first due to InsertRule)
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedDropExprs)
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
for _, r := range rule { for _, r := range rule {
err = manager.DeletePeerRule(r) err = manager.DeletePeerRule(r)
@@ -160,10 +165,90 @@ func TestNftablesManager(t *testing.T) {
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
defer func() {
err = manager.Close(nil)
require.NoError(t, err)
}()
ip := netip.MustParseAddr("100.96.0.2").Unmap()
testClient := &nftables.Conn{}
// Add accept rule first
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for the same traffic
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
require.NoError(t, err, "failed to add deny rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex int = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
hasPort80 := false
var action string
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
if lookup.SetName == "accept-http" {
hasAcceptHTTPSet = true
} else if lookup.SetName == "deny-http" {
hasDenyHTTPSet = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
hasPort80 = true
}
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
if verdict.Kind == expr.VerdictAccept {
action = "ACCEPT"
} else if verdict.Kind == expr.VerdictDrop {
action = "DROP"
}
}
}
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
acceptRuleIndex = i
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
denyRuleIndex = i
}
}
require.NotEqual(t, -1, acceptRuleIndex, "accept rule should exist in nftables")
require.NotEqual(t, -1, denyRuleIndex, "deny rule should exist in nftables")
require.Less(t, denyRuleIndex, acceptRuleIndex,
"deny rule should come before accept rule in nftables chain (deny at index %d, accept at index %d)",
denyRuleIndex, acceptRuleIndex)
}
func TestNFtablesCreatePerformance(t *testing.T) { func TestNFtablesCreatePerformance(t *testing.T) {
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "wg-test"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{

View File

@@ -18,6 +18,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {

View File

@@ -27,6 +27,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {

View File

@@ -70,14 +70,13 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
// outgoingRules is used for hooks only outgoingRules map[netip.Addr]RuleSet
outgoingRules map[netip.Addr]RuleSet incomingDenyRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks incomingRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet routeRules RouteRules
routeRules RouteRules decoders sync.Pool
decoders sync.Pool wgIface common.IFaceMapper
wgIface common.IFaceMapper nativeFirewall firewall.Manager
nativeFirewall firewall.Manager
mutex sync.RWMutex mutex sync.RWMutex
@@ -186,6 +185,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}, },
nativeFirewall: nativeFirewall, nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet), outgoingRules: make(map[netip.Addr]RuleSet),
incomingDenyRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet), incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface, wgIface: iface,
localipmanager: newLocalIPManager(), localipmanager: newLocalIPManager(),
@@ -417,10 +417,17 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if _, ok := m.incomingRules[r.ip]; !ok { var targetMap map[netip.Addr]RuleSet
m.incomingRules[r.ip] = make(RuleSet) if r.drop {
targetMap = m.incomingDenyRules
} else {
targetMap = m.incomingRules
} }
m.incomingRules[r.ip][r.id] = r
if _, ok := targetMap[r.ip]; !ok {
targetMap[r.ip] = make(RuleSet)
}
targetMap[r.ip][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
@@ -507,10 +514,24 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if _, ok := m.incomingRules[r.ip][r.id]; !ok { var sourceMap map[netip.Addr]RuleSet
if r.drop {
sourceMap = m.incomingDenyRules
} else {
sourceMap = m.incomingRules
}
if ruleset, ok := sourceMap[r.ip]; ok {
if _, exists := ruleset[r.id]; !exists {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(ruleset, r.id)
if len(ruleset) == 0 {
delete(sourceMap, r.ip)
}
} else {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
} }
delete(m.incomingRules[r.ip], r.id)
return nil return nil
} }
@@ -572,7 +593,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil return nil
} }
// FilterOutBound filters outgoing packets // FilterOutbound filters outgoing packets
func (m *Manager) FilterOutbound(packetData []byte, size int) bool { func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.filterOutbound(packetData, size) return m.filterOutbound(packetData, size)
} }
@@ -761,7 +782,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked { if blocked {
_, pnum := getProtocolFromPacket(d) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
@@ -971,26 +992,28 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) { func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
if m.isSpecialICMP(d) { if m.isSpecialICMP(d) {
return nil, false return nil, false
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
// Default policy: DROP ALL
return nil, true return nil, true
} }
@@ -1013,6 +1036,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) { func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 { if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue continue
@@ -1045,6 +1069,7 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true return rule.mgmtId, rule.drop, true
} }
} }
return nil, false, false return nil, false, false
} }
@@ -1116,6 +1141,7 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu
m.mutex.Lock() m.mutex.Lock()
if in { if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule) m.incomingRules[r.ip] = make(map[string]PeerRule)
} }
@@ -1136,6 +1162,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules { for _, arr := range m.incomingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
@@ -1144,6 +1171,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
} }
} }
} }
// Check outgoing hooks
for _, arr := range m.outgoingRules { for _, arr := range m.outgoingRules {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {

View File

@@ -458,6 +458,31 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionDrop, ruleAction: fw.ActionDrop,
shouldBeBlocked: true, shouldBeBlocked: true,
}, },
{
name: "Peer ACL - Drop rule should override accept all rule",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 22,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{22}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Peer ACL - Drop all traffic from specific IP",
srcIP: "100.10.0.99",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.99",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
} }
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
@@ -468,13 +493,11 @@ func TestPeerACLFiltering(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop { if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule // add general accept rule for the same IP to test drop rule precedence
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
rules, err := manager.AddPeerFiltering( rules, err := manager.AddPeerFiltering(
nil, nil,
net.ParseIP("0.0.0.0"), net.ParseIP(tc.ruleIP),
fw.ProtocolALL, fw.ProtocolALL,
nil, nil,
nil, nil,

View File

@@ -136,9 +136,22 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
// Check rules exist in appropriate maps
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; !ok { peerRule, ok := r.(*PeerRule)
t.Errorf("rule2 is not in the incomingRules") if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule exists in deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if !found {
t.Errorf("rule2 is not in the expected rules map")
} }
} }
@@ -150,9 +163,22 @@ func TestManagerDeleteRule(t *testing.T) {
} }
} }
// Check rules are removed from appropriate maps
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; ok { peerRule, ok := r.(*PeerRule)
t.Errorf("rule2 is not in the incomingRules") if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule is removed from deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
} else {
_, found = m.incomingRules[ip][r.ID()]
}
if found {
t.Errorf("rule2 should be removed from the rules map")
} }
} }
} }
@@ -196,16 +222,17 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule var addedRule PeerRule
if tt.in { if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 { if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return return
} }
for _, rule := range manager.incomingRules[tt.ip] { for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} else { } else {
if len(manager.outgoingRules) != 1 { if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip] { for _, rule := range manager.outgoingRules[tt.ip] {
@@ -261,8 +288,8 @@ func TestManagerReset(t *testing.T) {
return return
} }
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 { if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
t.Errorf("rules is not empty") t.Errorf("rules are not empty")
} }
} }

View File

@@ -314,7 +314,7 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
trace.AddResult(StageRouting, "Packet destined for local delivery", true) trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) ruleId, blocked := m.peerACLsBlock(srcIP, d, packetData)
strRuleId := "<no id>" strRuleId := "<no id>"
if ruleId != nil { if ruleId != nil {

View File

@@ -56,10 +56,11 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder activityRecorder *ActivityRecorder
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -69,6 +70,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
mtu: mtu,
address: address, address: address,
activityRecorder: NewActivityRecorder(), activityRecorder: NewActivityRecorder(),
} }
@@ -80,6 +82,10 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
return ib return ib
} }
func (s *ICEBind) MTU() uint16 {
return s.mtu
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false s.closed = false
s.closedChanMu.Lock() s.closedChanMu.Lock()
@@ -158,6 +164,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address, WGAddress: s.address,
MTU: s.mtu,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {

View File

@@ -18,6 +18,7 @@ import (
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -44,6 +45,7 @@ type UniversalUDPMuxParams struct {
Net transport.Net Net transport.Net
FilterFn FilterFn FilterFn FilterFn
WGAddress wgaddr.Address WGAddress wgaddr.Address
MTU uint16
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -84,7 +86,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// just ignore other packets printing an warning message. // just ignore other packets printing an warning message.
// It is a blocking method, consider running in a go routine. // It is a blocking method, consider running in a go routine.
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
buf := make([]byte, 1500) buf := make([]byte, m.params.MTU+bufsize.WGBufferOverhead)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@@ -0,0 +1,9 @@
package bufsize
const (
// WGBufferOverhead represents the additional buffer space needed beyond MTU
// for WireGuard packet encapsulation (WG header + UDP + IP + safety margin)
// Original hardcoded buffers were 1500, default MTU is 1280, so overhead = 220
// TODO: Calculate this properly based on actual protocol overhead instead of using hardcoded difference
WGBufferOverhead = 220
)

View File

@@ -17,6 +17,7 @@ type WGTunDevice interface {
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -21,7 +21,7 @@ type WGTunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool disableDNS bool
@@ -33,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -58,7 +58,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
searchDomainsToString = "" searchDomainsToString = ""
} }
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)
return nil, err return nil, err
@@ -137,6 +137,10 @@ func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *WGTunDevice) MTU() uint16 {
return t.mtu
}
func (t *WGTunDevice) FilteredDevice() *FilteredDevice { func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }

View File

@@ -21,7 +21,7 @@ type TunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
@@ -30,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -42,7 +42,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
} }
func (t *TunDevice) Create() (WGConfigurer, error) { func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu) tunDevice, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
@@ -111,6 +111,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -22,6 +22,7 @@ type TunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunFd int tunFd int
@@ -31,12 +32,13 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunFd: tunFd, tunFd: tunFd,
} }
@@ -125,6 +127,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement // todo implement
return nil return nil

View File

@@ -24,7 +24,7 @@ type TunKernelDevice struct {
address wgaddr.Address address wgaddr.Address
wgPort int wgPort int
key string key string
mtu int mtu uint16
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
transportNet transport.Net transportNet transport.Net
@@ -36,7 +36,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@@ -66,7 +66,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
// TODO: do a MTU discovery // TODO: do a MTU discovery
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name) log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
if err := link.setMTU(t.mtu); err != nil { if err := link.setMTU(int(t.mtu)); err != nil {
return nil, fmt.Errorf("set mtu: %w", err) return nil, fmt.Errorf("set mtu: %w", err)
} }
@@ -96,7 +96,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter()) rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter(), t.mtu)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -111,6 +111,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address, WGAddress: t.address,
MTU: t.mtu,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
@@ -158,6 +159,10 @@ func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunKernelDevice) MTU() uint16 {
return t.mtu
}
func (t *TunKernelDevice) DeviceName() string { func (t *TunKernelDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -1,6 +1,3 @@
//go:build !android
// +build !android
package device package device
import ( import (
@@ -22,7 +19,7 @@ type TunNetstackDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
listenAddress string listenAddress string
iceBind *bind.ICEBind iceBind *bind.ICEBind
@@ -35,7 +32,7 @@ type TunNetstackDevice struct {
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -47,7 +44,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
} }
} }
func (t *TunNetstackDevice) Create() (WGConfigurer, error) { func (t *TunNetstackDevice) create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
@@ -57,7 +54,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
} }
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu))
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create() tunIface, net, err := t.nsTun.Create()
if err != nil { if err != nil {
@@ -125,6 +122,10 @@ func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunNetstackDevice) MTU() uint16 {
return t.mtu
}
func (t *TunNetstackDevice) DeviceName() string { func (t *TunNetstackDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -0,0 +1,7 @@
//go:build android
package device
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create()
}

View File

@@ -0,0 +1,7 @@
//go:build !android
package device
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
return t.create()
}

View File

@@ -20,7 +20,7 @@ type USPDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
@@ -29,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
return &USPDevice{ return &USPDevice{
@@ -44,9 +44,9 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu
func (t *USPDevice) Create() (WGConfigurer, error) { func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu) tunIface, err := tun.CreateTUN(t.name, int(t.mtu))
if err != nil { if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) log.Debugf("failed to create tun interface (%s, %d): %s", t.name, int(t.mtu), err)
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.filteredDevice = newDeviceFilter(tunIface) t.filteredDevice = newDeviceFilter(tunIface)
@@ -118,6 +118,10 @@ func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *USPDevice) MTU() uint16 {
return t.mtu
}
func (t *USPDevice) DeviceName() string { func (t *USPDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -23,7 +23,7 @@ type TunDevice struct {
address wgaddr.Address address wgaddr.Address
port int port int
key string key string
mtu int mtu uint16
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
@@ -33,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -59,7 +59,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, err return nil, err
} }
log.Info("create tun interface") log.Info("create tun interface")
tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu) tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, int(t.mtu))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
@@ -144,6 +144,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) MTU() uint16 {
return t.mtu
}
func (t *TunDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }

View File

@@ -15,6 +15,7 @@ type WGTunDevice interface {
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -26,6 +26,8 @@ import (
const ( const (
DefaultMTU = 1280 DefaultMTU = 1280
MinMTU = 576
MaxMTU = 8192
DefaultWgPort = 51820 DefaultWgPort = 51820
WgInterfaceDefault = configurer.WgInterfaceDefault WgInterfaceDefault = configurer.WgInterfaceDefault
) )
@@ -35,6 +37,17 @@ var (
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found") ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
) )
// ValidateMTU validates that MTU is within acceptable range
func ValidateMTU(mtu uint16) error {
if mtu < MinMTU {
return fmt.Errorf("MTU %d below minimum (%d bytes)", mtu, MinMTU)
}
if mtu > MaxMTU {
return fmt.Errorf("MTU %d exceeds maximum supported size (%d bytes)", mtu, MaxMTU)
}
return nil
}
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Free() error Free() error
@@ -45,7 +58,7 @@ type WGIFaceOpts struct {
Address string Address string
WGPort int WGPort int
WGPrivKey string WGPrivKey string
MTU int MTU uint16
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
@@ -82,6 +95,10 @@ func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
func (w *WGIface) MTU() uint16 {
return w.tun.MTU()
}
// ToInterface returns the net.Interface for the Wireguard interface // ToInterface returns the net.Interface for the Wireguard interface
func (r *WGIface) ToInterface() *net.Interface { func (r *WGIface) ToInterface() *net.Interface {
name := r.tun.DeviceName() name := r.tun.DeviceName()

View File

@@ -3,6 +3,7 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -14,7 +15,16 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
if netstack.IsEnabled() {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,

View File

@@ -17,7 +17,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -16,10 +16,10 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true, userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }

View File

@@ -22,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{} wgIFace := &WGIface{}
if netstack.IsEnabled() { if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -31,11 +31,11 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if device.WireGuardModuleIsLoaded() { if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort) wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU)
return wgIFace, nil return wgIFace, nil
} }
if device.ModuleTunIsLoaded() { if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -14,7 +14,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
@@ -135,7 +136,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}() }()
for { for {
buf := make([]byte, 1500) buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {

View File

@@ -17,6 +17,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -29,6 +30,7 @@ const (
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
mtu uint16
ebpfManager ebpfMgr.Manager ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn turnConnStore map[uint16]net.Conn
@@ -43,10 +45,11 @@ type WGEBPFProxy struct {
} }
// NewWGEBPFProxy create new WGEBPFProxy instance // NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy") log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{ wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu,
ebpfManager: ebpf.GetEbpfManagerInstance(), ebpfManager: ebpf.GetEbpfManagerInstance(),
turnConnStore: make(map[uint16]net.Conn), turnConnStore: make(map[uint16]net.Conn),
} }
@@ -138,7 +141,7 @@ func (p *WGEBPFProxy) Free() error {
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance. // From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
for p.ctx.Err() == nil { for p.ctx.Err() == nil {
if err := p.readAndForwardPacket(buf); err != nil { if err := p.readAndForwardPacket(buf); err != nil {
if p.ctx.Err() != nil { if p.ctx.Err() != nil {

View File

@@ -7,7 +7,7 @@ import (
) )
func TestWGEBPFProxy_connStore(t *testing.T) { func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(1) wgProxy := NewWGEBPFProxy(1, 1280)
p, _ := wgProxy.storeTurnConn(nil) p, _ := wgProxy.storeTurnConn(nil)
if p != 1 { if p != 1 {
@@ -27,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(1) wgProxy := NewWGEBPFProxy(1, 1280)
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535 wgProxy.lastUsedPort = 65535
@@ -43,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(1) wgProxy := NewWGEBPFProxy(1, 1280)
for i := 0; i < 65535; i++ { for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
@@ -103,7 +104,7 @@ func (e *ProxyWrapper) CloseConn() error {
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500) buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.readFromRemote(ctx, buf) n, err := p.readFromRemote(ctx, buf)
if err != nil { if err != nil {

View File

@@ -11,16 +11,18 @@ import (
type KernelFactory struct { type KernelFactory struct {
wgPort int wgPort int
mtu uint16
ebpfProxy *ebpf.WGEBPFProxy ebpfProxy *ebpf.WGEBPFProxy
} }
func NewKernelFactory(wgPort int) *KernelFactory { func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
f := &KernelFactory{ f := &KernelFactory{
wgPort: wgPort, wgPort: wgPort,
mtu: mtu,
} }
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
@@ -33,7 +35,7 @@ func NewKernelFactory(wgPort int) *KernelFactory {
func (w *KernelFactory) GetProxy() Proxy { func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil { if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
} }
return ebpf.NewProxyWrapper(w.ebpfProxy) return ebpf.NewProxyWrapper(w.ebpfProxy)

View File

@@ -9,19 +9,21 @@ import (
// KernelFactory todo: check eBPF support on FreeBSD // KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct { type KernelFactory struct {
wgPort int wgPort int
mtu uint16
} }
func NewKernelFactory(wgPort int) *KernelFactory { func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{ f := &KernelFactory{
wgPort: wgPort, wgPort: wgPort,
mtu: mtu,
} }
return f return f
} }
func (w *KernelFactory) GetProxy() Proxy { func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -16,7 +16,7 @@ func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) t.Fatalf("failed to initialize ebpf proxy: %s", err)
} }

View File

@@ -84,12 +84,12 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}{ }{
{ {
name: "userspace proxy", name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830), proxy: udpProxy.NewWGUDPProxy(51830, 1280),
}, },
} }
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) t.Fatalf("failed to initialize ebpf proxy: %s", err)
} }

View File

@@ -12,12 +12,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
type WGUDPProxy struct { type WGUDPProxy struct {
localWGListenPort int localWGListenPort int
mtu uint16
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
@@ -34,10 +36,11 @@ type WGUDPProxy struct {
} }
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
func NewWGUDPProxy(wgPort int) *WGUDPProxy { func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu,
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
return p return p
@@ -144,7 +147,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
} }
}() }()
buf := make([]byte, 1500) buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
for ctx.Err() == nil { for ctx.Err() == nil {
n, err := p.localConn.Read(buf) n, err := p.localConn.Read(buf)
if err != nil { if err != nil {
@@ -179,7 +182,7 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
} }
}() }()
buf := make([]byte, 1500) buf := make([]byte, p.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.remoteConnRead(ctx, buf) n, err := p.remoteConnRead(ctx, buf)
if err != nil { if err != nil {

View File

@@ -3,15 +3,17 @@ package auth
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/client/internal"
"github.com/stretchr/testify/require"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
) )
type mockHTTPClient struct { type mockHTTPClient struct {

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
@@ -17,6 +18,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -70,7 +72,7 @@ func (c *ConnectClient) RunOnAndroid(
tunAdapter device.TunAdapter, tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string, dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener, dnsReadyListener dns.ReadyListener,
) error { ) error {
// in case of non Android os these variables will be nil // in case of non Android os these variables will be nil
@@ -243,7 +245,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected() c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil {
log.Error(err)
return wrapErr(err)
}
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager) c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 { if len(relayURLs) > 0 {
if token != nil { if token != nil {
@@ -258,14 +268,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
} }
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
if err != nil {
log.Error(err)
return wrapErr(err)
}
checks := loginResp.GetChecks() checks := loginResp.GetChecks()
c.engineMutex.Lock() c.engineMutex.Lock()
@@ -443,6 +445,8 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
BlockInbound: config.BlockInbound, BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled, LazyConnectionEnabled: config.LazyConnectionEnabled,
MTU: selectMTU(config.MTU, peerConfig.Mtu),
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -465,6 +469,20 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
return engineConf, nil return engineConf, nil
} }
func selectMTU(localMTU uint16, peerMTU int32) uint16 {
var finalMTU uint16 = iface.DefaultMTU
if localMTU > 0 {
finalMTU = localMTU
} else if peerMTU > 0 {
finalMTU = uint16(peerMTU)
}
// Set global DNS MTU
dns.SetCurrentMTU(finalMTU)
return finalMTU
}
// connectToSignal creates Signal Service client and established a connection // connectToSignal creates Signal Service client and established a connection
func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) { func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) {
var sigTLSEnabled bool var sigTLSEnabled bool

View File

@@ -16,7 +16,7 @@ const (
) )
type resolvConf struct { type resolvConf struct {
nameServers []string nameServers []netip.Addr
searchDomains []string searchDomains []string
others []string others []string
} }
@@ -36,7 +36,7 @@ func parseBackupResolvConf() (*resolvConf, error) {
func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) { func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
rconf := &resolvConf{ rconf := &resolvConf{
searchDomains: make([]string, 0), searchDomains: make([]string, 0),
nameServers: make([]string, 0), nameServers: make([]netip.Addr, 0),
others: make([]string, 0), others: make([]string, 0),
} }
@@ -94,7 +94,11 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
if len(splitLines) != 2 { if len(splitLines) != 2 {
continue continue
} }
rconf.nameServers = append(rconf.nameServers, splitLines[1]) if addr, err := netip.ParseAddr(splitLines[1]); err == nil {
rconf.nameServers = append(rconf.nameServers, addr.Unmap())
} else {
log.Warnf("invalid nameserver address in resolv.conf: %s, skipping", splitLines[1])
}
continue continue
} }
@@ -104,31 +108,3 @@ func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
} }
return rconf, nil return rconf, nil
} }
// removeFirstNbNameserver removes the given nameserver from the given file if it is in the first position
// and writes the file back to the original location
func removeFirstNbNameserver(filename string, nameserverIP netip.Addr) error {
resolvConf, err := parseResolvConfFile(filename)
if err != nil {
return fmt.Errorf("parse backup resolv.conf: %w", err)
}
content, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("read %s: %w", filename, err)
}
if len(resolvConf.nameServers) > 1 && resolvConf.nameServers[0] == nameserverIP.String() {
newContent := strings.Replace(string(content), fmt.Sprintf("nameserver %s\n", nameserverIP), "", 1)
stat, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("stat %s: %w", filename, err)
}
if err := os.WriteFile(filename, []byte(newContent), stat.Mode()); err != nil {
return fmt.Errorf("write %s: %w", filename, err)
}
}
return nil
}

View File

@@ -3,13 +3,9 @@
package dns package dns
import ( import (
"net/netip"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_parseResolvConf(t *testing.T) { func Test_parseResolvConf(t *testing.T) {
@@ -99,9 +95,13 @@ options debug
t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains) t.Errorf("invalid parse result for search domains, expected: %v, got: %v", testCase.expectedSearch, cfg.searchDomains)
} }
ok = compareLists(cfg.nameServers, testCase.expectedNS) nsStrings := make([]string, len(cfg.nameServers))
for i, ns := range cfg.nameServers {
nsStrings[i] = ns.String()
}
ok = compareLists(nsStrings, testCase.expectedNS)
if !ok { if !ok {
t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, cfg.nameServers) t.Errorf("invalid parse result for ns domains, expected: %v, got: %v", testCase.expectedNS, nsStrings)
} }
ok = compareLists(cfg.others, testCase.expectedOther) ok = compareLists(cfg.others, testCase.expectedOther)
@@ -176,87 +176,3 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg) t.Errorf("unexpected resolv.conf content: %v", cfg)
} }
} }
func TestRemoveFirstNbNameserver(t *testing.T) {
testCases := []struct {
name string
content string
ipToRemove string
expected string
}{
{
name: "Unrelated nameservers with comments and options",
content: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
ipToRemove: "9.9.9.9",
expected: `# This is a comment
options rotate
nameserver 1.1.1.1
# Another comment
nameserver 8.8.4.4
search example.com`,
},
{
name: "First nameserver matches",
content: `search example.com
nameserver 9.9.9.9
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
ipToRemove: "9.9.9.9",
expected: `search example.com
# oof, a comment
nameserver 8.8.4.4
options attempts:5`,
},
{
name: "Target IP not the first nameserver",
// nolint:dupword
content: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
ipToRemove: "9.9.9.9",
// nolint:dupword
expected: `# Comment about the first nameserver
nameserver 8.8.4.4
# Comment before our target
nameserver 9.9.9.9
options timeout:2`,
},
{
name: "Only nameserver matches",
content: `options debug
nameserver 9.9.9.9
search localdomain`,
ipToRemove: "9.9.9.9",
expected: `options debug
nameserver 9.9.9.9
search localdomain`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := t.TempDir()
tempFile := filepath.Join(tempDir, "resolv.conf")
err := os.WriteFile(tempFile, []byte(tc.content), 0644)
assert.NoError(t, err)
ip, err := netip.ParseAddr(tc.ipToRemove)
require.NoError(t, err, "Failed to parse IP address")
err = removeFirstNbNameserver(tempFile, ip)
assert.NoError(t, err)
content, err := os.ReadFile(tempFile)
assert.NoError(t, err)
assert.Equal(t, tc.expected, string(content), "The resulting content should match the expected output.")
})
}
}

View File

@@ -146,7 +146,7 @@ func isNbParamsMissing(nbSearchDomains []string, nbNameserverIP netip.Addr, rCon
return true return true
} }
if rConf.nameServers[0] != nbNameserverIP.String() { if rConf.nameServers[0] != nbNameserverIP {
return true return true
} }

View File

@@ -29,7 +29,7 @@ type fileConfigurator struct {
repair *repair repair *repair
originalPerms os.FileMode originalPerms os.FileMode
nbNameserverIP netip.Addr nbNameserverIP netip.Addr
originalNameservers []string originalNameservers []netip.Addr
} }
func newFileConfigurator() (*fileConfigurator, error) { func newFileConfigurator() (*fileConfigurator, error) {
@@ -70,7 +70,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
} }
// getOriginalNameservers returns the nameservers that were found in the original resolv.conf // getOriginalNameservers returns the nameservers that were found in the original resolv.conf
func (f *fileConfigurator) getOriginalNameservers() []string { func (f *fileConfigurator) getOriginalNameservers() []netip.Addr {
return f.originalNameservers return f.originalNameservers
} }
@@ -128,20 +128,14 @@ func (f *fileConfigurator) backup() error {
} }
func (f *fileConfigurator) restore() error { func (f *fileConfigurator) restore() error {
err := removeFirstNbNameserver(fileDefaultResolvConfBackupLocation, f.nbNameserverIP) if err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath); err != nil {
if err != nil {
log.Errorf("Failed to remove netbird nameserver from %s on backup restore: %s", fileDefaultResolvConfBackupLocation, err)
}
err = copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath)
if err != nil {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress netip.Addr) error {
resolvConf, err := parseDefaultResolvConf() resolvConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
return fmt.Errorf("parse current resolv.conf: %w", err) return fmt.Errorf("parse current resolv.conf: %w", err)
@@ -152,16 +146,9 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile() return restoreResolvConfFile()
} }
currentDNSAddress, err := netip.ParseAddr(resolvConf.nameServers[0])
// not a valid first nameserver -> restore
if err != nil {
log.Errorf("restoring unclean shutdown: parse dns address %s failed: %s", resolvConf.nameServers[0], err)
return restoreResolvConfFile()
}
// current address is still netbird's non-available dns address -> restore // current address is still netbird's non-available dns address -> restore
// comparing parsed addresses only, to remove ambiguity currentDNSAddress := resolvConf.nameServers[0]
if currentDNSAddress.String() == storedDNSAddress.String() { if currentDNSAddress == storedDNSAddress {
return restoreResolvConfFile() return restoreResolvConfFile()
} }

View File

@@ -166,9 +166,10 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error { func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true) if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration") log.Errorf("Unable to get system DNS configuration")
return err return fmt.Errorf("recordSystemDNSSettings(): %w", err)
}
} }
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
@@ -239,7 +240,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address inServerAddressesArray = false // Stop reading after finding the first IPv4 address
} }
} }
@@ -250,7 +251,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
// default to 53 port // default to 53 port
dnsSettings.ServerPort = defaultPort dnsSettings.ServerPort = DefaultPort
return dnsSettings, nil return dnsSettings, nil
} }

View File

@@ -42,7 +42,7 @@ func (t osManagerType) String() string {
type restoreHostManager interface { type restoreHostManager interface {
hostManager hostManager
restoreUncleanShutdownDNS(*netip.Addr) error restoreUncleanShutdownDNS(netip.Addr) error
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
@@ -130,8 +130,9 @@ func checkStub() bool {
return true return true
} }
systemdResolvedAddr := netip.AddrFrom4([4]byte{127, 0, 0, 53}) // 127.0.0.53
for _, ns := range rConf.nameServers { for _, ns := range rConf.nameServers {
if ns == "127.0.0.53" { if ns == systemdResolvedAddr {
return true return true
} }
} }

View File

@@ -64,9 +64,10 @@ const (
) )
type registryConfigurator struct { type registryConfigurator struct {
guid string guid string
routingAll bool routingAll bool
gpo bool gpo bool
nrptEntryCount int
} }
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -177,7 +178,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil { if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err) log.Errorf("failed to update shutdown state: %s", err)
} }
@@ -193,13 +198,24 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
return fmt.Errorf("add dns match policy: %w", err) return fmt.Errorf("add dns match policy: %w", err)
} }
r.nrptEntryCount = count
} else { } else {
if err := r.removeDNSMatchPolicies(); err != nil { if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err) return fmt.Errorf("remove dns match policies: %w", err)
} }
r.nrptEntryCount = 0
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
if err := r.updateSearchDomains(searchDomains); err != nil { if err := r.updateSearchDomains(searchDomains); err != nil {
@@ -216,32 +232,38 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
} }
r.routingAll = true r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) log.Infof("configured %s:%d as main DNS forwarder for this peer", ip, DefaultPort)
return nil return nil
} }
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
if r.gpo { for i, domain := range domains {
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil { policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
return fmt.Errorf("configure GPO DNS policy: %w", err) if r.gpo {
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
} }
singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
}
log.Debugf("added NRPT entry for domain: %s", domain)
}
if r.gpo {
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err) log.Warnf("failed to refresh group policy: %v", err)
} }
} else {
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
return fmt.Errorf("configure local DNS policy: %w", err)
}
} }
log.Infof("added %d match domains. Domain list: %s", len(domains), domains) log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return nil return len(domains), nil
} }
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
@@ -374,12 +396,25 @@ func (r *registryConfigurator) restoreHostDNS() error {
func (r *registryConfigurator) removeDNSMatchPolicies() error { func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error var merr *multierror.Error
// Try to remove the base entries (for backward compatibility)
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { for i := 0; i < r.nrptEntryCount; i++ {
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err)) localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
}
} }
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {

View File

@@ -1,38 +1,31 @@
package dns package dns
import ( import (
"fmt"
"net/netip" "net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
type hostsDNSHolder struct { type hostsDNSHolder struct {
unprotectedDNSList map[string]struct{} unprotectedDNSList map[netip.AddrPort]struct{}
mutex sync.RWMutex mutex sync.RWMutex
} }
func newHostsDNSHolder() *hostsDNSHolder { func newHostsDNSHolder() *hostsDNSHolder {
return &hostsDNSHolder{ return &hostsDNSHolder{
unprotectedDNSList: make(map[string]struct{}), unprotectedDNSList: make(map[netip.AddrPort]struct{}),
} }
} }
func (h *hostsDNSHolder) set(list []string) { func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Lock() h.mutex.Lock()
h.unprotectedDNSList = make(map[string]struct{}) h.unprotectedDNSList = make(map[netip.AddrPort]struct{})
for _, dns := range list { for _, addrPort := range list {
dnsAddr, err := h.normalizeAddress(dns) h.unprotectedDNSList[addrPort] = struct{}{}
if err != nil {
continue
}
h.unprotectedDNSList[dnsAddr] = struct{}{}
} }
h.mutex.Unlock() h.mutex.Unlock()
} }
func (h *hostsDNSHolder) get() map[string]struct{} { func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock() h.mutex.RLock()
l := h.unprotectedDNSList l := h.unprotectedDNSList
h.mutex.RUnlock() h.mutex.RUnlock()
@@ -40,24 +33,10 @@ func (h *hostsDNSHolder) get() map[string]struct{} {
} }
//nolint:unused //nolint:unused
func (h *hostsDNSHolder) isContain(upstream string) bool { func (h *hostsDNSHolder) contains(upstream netip.AddrPort) bool {
h.mutex.RLock() h.mutex.RLock()
defer h.mutex.RUnlock() defer h.mutex.RUnlock()
_, ok := h.unprotectedDNSList[upstream] _, ok := h.unprotectedDNSList[upstream]
return ok return ok
} }
func (h *hostsDNSHolder) normalizeAddress(addr string) (string, error) {
a, err := netip.ParseAddr(addr)
if err != nil {
log.Errorf("invalid upstream IP address: %s, error: %s", addr, err)
return "", err
}
if a.Is4() {
return fmt.Sprintf("%s:53", addr), nil
} else {
return fmt.Sprintf("[%s]:53", addr), nil
}
}

View File

@@ -50,7 +50,7 @@ func (m *MockServer) DnsIP() netip.Addr {
return netip.MustParseAddr("100.10.254.255") return netip.MustParseAddr("100.10.254.255")
} }
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) { func (m *MockServer) OnUpdatedHostDNSServer(addrs []netip.AddrPort) {
// TODO implement me // TODO implement me
panic("implement me") panic("implement me")
} }

View File

@@ -245,7 +245,7 @@ func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error {
return nil return nil
} }
func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (n *networkManagerDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := n.restoreHostDNS(); err != nil { if err := n.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via network-manager: %w", err) return fmt.Errorf("restoring dns via network-manager: %w", err)
} }

View File

@@ -40,7 +40,7 @@ type resolvconf struct {
implType resolvconfType implType resolvconfType
originalSearchDomains []string originalSearchDomains []string
originalNameServers []string originalNameServers []netip.Addr
othersConfigs []string othersConfigs []string
} }
@@ -110,7 +110,7 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
return nil return nil
} }
func (r *resolvconf) getOriginalNameservers() []string { func (r *resolvconf) getOriginalNameservers() []netip.Addr {
return r.originalNameServers return r.originalNameServers
} }
@@ -158,7 +158,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
return nil return nil
} }
func (r *resolvconf) restoreUncleanShutdownDNS(*netip.Addr) error { func (r *resolvconf) restoreUncleanShutdownDNS(netip.Addr) error {
if err := r.restoreHostDNS(); err != nil { if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err) return fmt.Errorf("restoring dns for interface %s: %w", r.ifaceName, err)
} }

View File

@@ -42,7 +42,7 @@ type Server interface {
Stop() Stop()
DnsIP() netip.Addr DnsIP() netip.Addr
UpdateDNSServer(serial uint64, update nbdns.Config) error UpdateDNSServer(serial uint64, update nbdns.Config) error
OnUpdatedHostDNSServer(strings []string) OnUpdatedHostDNSServer(addrs []netip.AddrPort)
SearchDomains() []string SearchDomains() []string
ProbeAvailability() ProbeAvailability()
} }
@@ -55,7 +55,7 @@ type nsGroupsByDomain struct {
// hostManagerWithOriginalNS extends the basic hostManager interface // hostManagerWithOriginalNS extends the basic hostManager interface
type hostManagerWithOriginalNS interface { type hostManagerWithOriginalNS interface {
hostManager hostManager
getOriginalNameservers() []string getOriginalNameservers() []netip.Addr
} }
// DefaultServer dns server object // DefaultServer dns server object
@@ -136,7 +136,7 @@ func NewDefaultServer(
func NewDefaultServerPermanentUpstream( func NewDefaultServerPermanentUpstream(
ctx context.Context, ctx context.Context,
wgInterface WGIface, wgInterface WGIface,
hostsDnsList []string, hostsDnsList []netip.AddrPort,
config nbdns.Config, config nbdns.Config,
listener listener.NetworkChangeListener, listener listener.NetworkChangeListener,
statusRecorder *peer.Status, statusRecorder *peer.Status,
@@ -144,6 +144,7 @@ func NewDefaultServerPermanentUpstream(
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@@ -340,7 +341,7 @@ func (s *DefaultServer) disableDNS() error {
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones // OnUpdatedHostDNSServer update the DNS servers addresses for root zones
// It will be applied if the mgm server do not enforce DNS settings for root zone // It will be applied if the mgm server do not enforce DNS settings for root zone
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) {
s.hostsDNSHolder.set(hostsDnsList) s.hostsDNSHolder.set(hostsDnsList)
// Check if there's any root handler // Check if there's any root handler
@@ -461,7 +462,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() { if s.service.RuntimePort() != DefaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " + log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver") "Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
@@ -581,14 +582,13 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
} }
for _, ns := range originalNameservers { for _, ns := range originalNameservers {
if ns == config.ServerIP.String() { if ns == config.ServerIP {
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
continue continue
} }
ns = formatAddr(ns, defaultPort) addrPort := netip.AddrPortFrom(ns, DefaultPort)
handler.upstreamServers = append(handler.upstreamServers, addrPort)
handler.upstreamServers = append(handler.upstreamServers, ns)
} }
handler.deactivate = func(error) { /* always active */ } handler.deactivate = func(error) { /* always active */ }
handler.reactivate = func() { /* always active */ } handler.reactivate = func() { /* always active */ }
@@ -695,7 +695,13 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue continue
} }
handler.upstreamServers = append(handler.upstreamServers, getNSHostPort(ns))
if ns.IP == s.service.RuntimeIP() {
log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP)
continue
}
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
} }
if len(handler.upstreamServers) == 0 { if len(handler.upstreamServers) == 0 {
@@ -770,18 +776,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap s.dnsMuxMap = muxUpdateMap
} }
func getNSHostPort(ns nbdns.NameServer) string {
return formatAddr(ns.IP.String(), ns.Port)
}
// formatAddr formats a nameserver address with port, handling IPv6 addresses properly
func formatAddr(address string, port int) string {
if ip, err := netip.ParseAddr(address); err == nil && ip.Is6() {
return fmt.Sprintf("[%s]:%d", address, port)
}
return fmt.Sprintf("%s:%d", address, port)
}
// upstreamCallbacks returns two functions, the first one is used to deactivate // upstreamCallbacks returns two functions, the first one is used to deactivate
// the upstream resolver from the configuration, the second one is used to // the upstream resolver from the configuration, the second one is used to
// reactivate it. Not allowed to call reactivate before deactivate. // reactivate it. Not allowed to call reactivate before deactivate.
@@ -879,10 +873,7 @@ func (s *DefaultServer) addHostRootZone() {
return return
} }
handler.upstreamServers = make([]string, 0) handler.upstreamServers = maps.Keys(hostDNSServers)
for k := range hostDNSServers {
handler.upstreamServers = append(handler.upstreamServers, k)
}
handler.deactivate = func(error) {} handler.deactivate = func(error) {}
handler.reactivate = func() {} handler.reactivate = func() {}
@@ -893,9 +884,9 @@ func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {
var states []peer.NSGroupState var states []peer.NSGroupState
for _, group := range groups { for _, group := range groups {
var servers []string var servers []netip.AddrPort
for _, ns := range group.NameServers { for _, ns := range group.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort())
} }
state := peer.NSGroupState{ state := peer.NSGroupState{
@@ -927,7 +918,7 @@ func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error,
func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { func generateGroupKey(nsGroup *nbdns.NameServerGroup) string {
var servers []string var servers []string
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
servers = append(servers, fmt.Sprintf("%s:%d", ns.IP, ns.Port)) servers = append(servers, ns.AddrPort().String())
} }
return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) return fmt.Sprintf("%v_%v", servers, nsGroup.Domains)
} }

View File

@@ -97,9 +97,9 @@ func init() {
} }
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []string var srvs []netip.AddrPort
for _, srv := range servers { for _, srv := range servers {
srvs = append(srvs, getNSHostPort(srv)) srvs = append(srvs, srv.AddrPort())
} }
return &upstreamResolverBase{ return &upstreamResolverBase{
domain: domain, domain: domain,
@@ -705,7 +705,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
var dnsList []string var dnsList []netip.AddrPort
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false) dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
@@ -715,7 +715,8 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
} }
defer dnsServer.Stop() defer dnsServer.Stop()
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"}) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer.OnUpdatedHostDNSServer([]netip.AddrPort{addrPort})
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort()) resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
_, err = resolver.LookupHost(context.Background(), "netbird.io") _, err = resolver.LookupHost(context.Background(), "netbird.io")
@@ -731,7 +732,8 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -823,7 +825,8 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
} }
defer wgIFace.Close() defer wgIFace.Close()
dnsConfig := nbdns.Config{} dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, peer.NewRecorder("mgm"), false) addrPort := netip.MustParseAddrPort("8.8.8.8:53")
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []netip.AddrPort{addrPort}, dnsConfig, nil, peer.NewRecorder("mgm"), false)
err = dnsServer.Initialize() err = dnsServer.Initialize()
if err != nil { if err != nil {
t.Errorf("failed to initialize DNS server: %v", err) t.Errorf("failed to initialize DNS server: %v", err)
@@ -2054,55 +2057,123 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
} }
func TestFormatAddr(t *testing.T) { func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
dnsServerIP := service.RuntimeIP()
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct { tests := []struct {
name string name string
address string nsGroups []*nbdns.NameServerGroup
port int expectedHandlers int
expected string expectedServers []netip.Addr
shouldFilterOwnIP bool
}{ }{
{ {
name: "IPv4 address", name: "FilterOwnDNSServerIP",
address: "8.8.8.8", nsGroups: []*nbdns.NameServerGroup{
port: 53, {
expected: "8.8.8.8:53", Primary: true,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
}, },
{ {
name: "IPv4 address with custom port", name: "AllServersFiltered",
address: "1.1.1.1", nsGroups: []*nbdns.NameServerGroup{
port: 5353, {
expected: "1.1.1.1:5353", Primary: false,
NameServers: []nbdns.NameServer{
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
},
expectedHandlers: 0,
expectedServers: []netip.Addr{},
shouldFilterOwnIP: true,
}, },
{ {
name: "IPv6 address", name: "MixedServersWithOwnIP",
address: "fd78:94bf:7df8::1", nsGroups: []*nbdns.NameServerGroup{
port: 53, {
expected: "[fd78:94bf:7df8::1]:53", Primary: false,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate
},
Domains: []string{"test.com"},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
}, },
{ {
name: "IPv6 address with custom port", name: "NoOwnIPInList",
address: "2001:db8::1", nsGroups: []*nbdns.NameServerGroup{
port: 5353, {
expected: "[2001:db8::1]:5353", Primary: true,
}, NameServers: []nbdns.NameServer{
{ {IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
name: "IPv6 localhost", {IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
address: "::1", },
port: 53, Domains: []string{},
expected: "[::1]:53", },
}, },
{ expectedHandlers: 1,
name: "Invalid address treated as hostname", expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
address: "dns.example.com", shouldFilterOwnIP: false,
port: 53,
expected: "dns.example.com:53",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := formatAddr(tt.address, tt.port) muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups)
assert.Equal(t, tt.expected, result) assert.NoError(t, err)
assert.Len(t, muxUpdates, tt.expectedHandlers)
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
if upstream.Addr() == expected {
found = true
break
}
}
assert.True(t, found, "Expected server %s not found", expected)
}
}
}) })
} }
} }

View File

@@ -7,7 +7,7 @@ import (
) )
const ( const (
defaultPort = 53 DefaultPort = 53
) )
type service interface { type service interface {

View File

@@ -122,7 +122,7 @@ func (s *serviceViaListener) RuntimePort() int {
defer s.listenerFlagLock.Unlock() defer s.listenerFlagLock.Unlock()
if s.ebpfService != nil { if s.ebpfService != nil {
return defaultPort return DefaultPort
} else { } else {
return int(s.listenPort) return int(s.listenPort)
} }
@@ -148,9 +148,9 @@ func (s *serviceViaListener) evalListenAddress() (netip.Addr, uint16, error) {
return s.customAddr.Addr(), s.customAddr.Port(), nil return s.customAddr.Addr(), s.customAddr.Port(), nil
} }
ip, ok := s.testFreePort(defaultPort) ip, ok := s.testFreePort(DefaultPort)
if ok { if ok {
return ip, defaultPort, nil return ip, DefaultPort, nil
} }
ebpfSrv, port, ok := s.tryToUseeBPF() ebpfSrv, port, ok := s.tryToUseeBPF()

View File

@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: lastIP, runtimeIP: lastIP,
runtimePort: defaultPort, runtimePort: DefaultPort,
} }
return s return s
} }

View File

@@ -235,7 +235,7 @@ func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error
return nil return nil
} }
func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemdDbusConfigurator) restoreUncleanShutdownDNS(netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via systemd: %w", err) return fmt.Errorf("restoring dns via systemd: %w", err)
} }

View File

@@ -27,7 +27,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create previous host manager: %w", err) return fmt.Errorf("create previous host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { if err := manager.restoreUncleanShutdownDNS(s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }

View File

@@ -5,8 +5,9 @@ import (
) )
type ShutdownState struct { type ShutdownState struct {
Guid string Guid string
GPO bool GPO bool
NRPTEntryCount int
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -15,8 +16,9 @@ func (s *ShutdownState) Name() string {
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
manager := &registryConfigurator{ manager := &registryConfigurator{
guid: s.Guid, guid: s.Guid,
gpo: s.GPO, gpo: s.GPO,
nrptEntryCount: s.NRPTEntryCount,
} }
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {

View File

@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -25,6 +26,12 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
var currentMTU uint16 = iface.DefaultMTU
func SetCurrentMTU(mtu uint16) {
currentMTU = mtu
}
const ( const (
UpstreamTimeout = 15 * time.Second UpstreamTimeout = 15 * time.Second
@@ -48,7 +55,7 @@ type upstreamResolverBase struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []netip.AddrPort
domain string domain string
disabled bool disabled bool
failsCount atomic.Int32 failsCount atomic.Int32
@@ -79,17 +86,20 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
// String returns a string representation of the upstream resolver // String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string { func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers) return fmt.Sprintf("upstream %s", u.upstreamServers)
} }
// ID returns the unique handler ID // ID returns the unique handler ID
func (u *upstreamResolverBase) ID() types.HandlerID { func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers) servers := slices.Clone(u.upstreamServers)
slices.Sort(servers) slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New() hash := sha256.New()
hash.Write([]byte(u.domain + ":")) hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ","))) for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8])) return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
} }
@@ -130,7 +140,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func() { func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel() defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream, r) rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}() }()
if err != nil { if err != nil {
@@ -197,7 +207,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)", "All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta // TODO add domain meta
) )
} }
@@ -258,7 +268,7 @@ func (u *upstreamResolverBase) ProbeAvailability() {
proto.SystemEvent_DNS, proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)", "All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": strings.Join(u.upstreamServers, ", ")}, map[string]string{"upstreams": u.upstreamServersString()},
) )
} }
} }
@@ -278,7 +288,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
operation := func() error { operation := func() error {
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServers)) return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default: default:
} }
@@ -291,7 +301,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
} }
} }
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff()) log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error") return fmt.Errorf("upstream check call error")
} }
@@ -301,7 +311,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return return
} }
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers) log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
@@ -331,13 +341,21 @@ func (u *upstreamResolverBase) disable(err error) {
go u.waitUntilResponse() go u.waitUntilResponse()
} }
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error { func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout) ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel() defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(ctx, server, r) _, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
return err return err
} }
@@ -346,8 +364,8 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
// If the passed context is nil, this will use Exchange instead of ExchangeContext. // If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers // MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower. // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = iface.DefaultMTU - (60 + 8) client.UDPSize = uint16(currentMTU - (60 + 8))
var ( var (
rm *dns.Msg rm *dns.Msg

View File

@@ -79,8 +79,8 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
} }
func (u *upstreamResolver) isLocalResolver(upstream string) bool { func (u *upstreamResolver) isLocalResolver(upstream string) bool {
if u.hostsDNSHolder.isContain(upstream) { if addrPort, err := netip.ParseAddrPort(upstream); err == nil {
return true return u.hostsDNSHolder.contains(addrPort)
} }
return false return false
} }

View File

@@ -62,6 +62,8 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP, err := netip.ParseAddr(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if err != nil { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err) log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
} else {
upstreamIP = upstreamIP.Unmap()
} }
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() { if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)

View File

@@ -59,7 +59,14 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers // Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
if addrPort, err := netip.ParseAddrPort(server); err == nil {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {
cancel() cancel()
@@ -128,7 +135,8 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivatePeriod: reactivatePeriod, reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact, failsTillDeact: failsTillDeact,
} }
resolver.upstreamServers = []string{"0.0.0.0:-1"} addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
resolver.failsTillDeact = 0 resolver.failsTillDeact = 0
resolver.reactivatePeriod = time.Microsecond * 100 resolver.reactivatePeriod = time.Microsecond * 100

View File

@@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
defer cancel() defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(ctx, w, question, resp, domain, err)
return nil return nil
} }
@@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
} }
} }
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response // handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) { func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError var dnsErr *net.DNSError
switch { switch {
case errors.As(err, &dnsErr): case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound { if dnsErr.IsNotFound {
// Pass through NXDOMAIN f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
resp.Rcode = dns.RcodeNameError
} }
if dnsErr.Server != "" { if dnsErr.Server != "" {
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err) log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
} else { } else {
log.Warnf(errResolveFailed, domain, err) log.Warnf(errResolveFailed, domain, err)
} }

View File

@@ -3,6 +3,7 @@ package dnsfwd
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
@@ -16,8 +17,8 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
@@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
assert.Len(t, matches, 3, "Should match 3 patterns") assert.Len(t, matches, 3, "Should match 3 patterns")
} }
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
forwarder.UpdateDomains(entries)
tests := []struct {
name string
queryType uint16
setupMocks func()
expectedCode int
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
description string
}{
{
name: "domain exists but no AAAA records (NODATA)",
queryType: dns.TypeAAAA,
setupMocks: func() {
// First query for AAAA returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for A records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no records of requested type",
},
{
name: "domain exists but no A records (NODATA)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no A records",
},
{
name: "domain doesn't exist (NXDOMAIN)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA also returns not found (domain doesn't exist)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
},
expectedCode: dns.RcodeNameError,
expectNoAnswer: true,
description: "Should return NXDOMAIN when domain doesn't exist at all",
},
{
name: "domain exists with records (normal success)",
queryType: dns.TypeA,
setupMocks: func() {
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
// Expect firewall update for successful resolution
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: false,
description: "Should return NOERROR with answer when records exist",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset mock expectations
mockResolver.ExpectedCalls = nil
mockResolver.Calls = nil
mockFirewall.ExpectedCalls = nil
mockFirewall.Calls = nil
tt.setupMocks()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
})
}
}
func TestDNSForwarder_EmptyQuery(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions // Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})

View File

@@ -55,11 +55,11 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto" sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -125,6 +125,8 @@ type EngineConfig struct {
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
MTU uint16
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -254,6 +256,7 @@ func NewEngine(
} }
engine.stateManager = statemanager.New(path) engine.stateManager = statemanager.New(path)
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine return engine
} }
@@ -346,6 +349,10 @@ func (e *Engine) Start() error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if err := iface.ValidateMTU(e.config.MTU); err != nil {
return fmt.Errorf("invalid MTU configuration: %w", err)
}
if e.cancel != nil { if e.cancel != nil {
e.cancel() e.cancel()
} }
@@ -1110,15 +1117,16 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
} }
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix.Masked(), Network: prefix.Masked(),
Domains: domain.FromPunycodeList(protoRoute.Domains), Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer, Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric), Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade, Masquerade: protoRoute.Masquerade,
KeepRoute: protoRoute.KeepRoute, KeepRoute: protoRoute.KeepRoute,
SkipAutoApply: protoRoute.SkipAutoApply,
} }
routes = append(routes, convertedRoute) routes = append(routes, convertedRoute)
} }
@@ -1330,52 +1338,17 @@ func (e *Engine) receiveSignalEvents() {
} }
switch msg.GetBody().Type { switch msg.GetBody().Type {
case sProto.Body_OFFER: case sProto.Body_OFFER, sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg) offerAnswer, err := convertToOfferAnswer(msg)
if err != nil { if err != nil {
return err return err
} }
var rosenpassPubKey []byte if msg.Body.Type == sProto.Body_OFFER {
rosenpassAddr := "" conn.OnRemoteOffer(*offerAnswer)
if msg.GetBody().GetRosenpassConfig() != nil { } else {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey() conn.OnRemoteAnswer(*offerAnswer)
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
} }
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
}
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_CANDIDATE: case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload) candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
if err != nil { if err != nil {
@@ -1525,7 +1498,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
Address: e.config.WgAddr, Address: e.config.WgAddr,
WGPort: e.config.WgPort, WGPort: e.config.WgPort,
WGPrivKey: e.config.WgPrivateKey.String(), WGPrivKey: e.config.WgPrivateKey.String(),
MTU: iface.DefaultMTU, MTU: e.config.MTU,
TransportNet: transportNet, TransportNet: transportNet,
FilterFn: e.addrViaRoutes, FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS, DisableDNS: e.config.DisableDNS,
@@ -2073,3 +2046,44 @@ func createFile(path string) error {
} }
return file.Close() return file.Close()
} }
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return nil, err
}
var (
rosenpassPubKey []byte
rosenpassAddr string
)
if cfg := msg.GetBody().GetRosenpassConfig(); cfg != nil {
rosenpassPubKey = cfg.GetRosenpassPubKey()
rosenpassAddr = cfg.GetRosenpassServerAddr()
}
// Handle optional SessionID
var sessionID *peer.ICESessionID
if sessionBytes := msg.GetBody().GetSessionId(); sessionBytes != nil {
if id, err := peer.ICESessionIDFromBytes(sessionBytes); err != nil {
log.Warnf("Invalid session ID in message: %v", err)
sessionID = nil // Set to nil if conversion fails
} else {
sessionID = &id
}
}
offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
SessionID: sessionID,
}
return &offerAnswer, nil
}

View File

@@ -27,6 +27,8 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@@ -43,8 +45,6 @@ import (
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -54,8 +54,10 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
@@ -216,7 +218,7 @@ func TestEngine_SSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine( engine := NewEngine(
ctx, cancel, ctx, cancel,
&signal.MockClient{}, &signal.MockClient{},
@@ -228,6 +230,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
ServerSSHAllowed: true, ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
}, },
MobileDependency{}, MobileDependency{},
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
@@ -361,7 +364,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine( engine := NewEngine(
ctx, cancel, ctx, cancel,
&signal.MockClient{}, &signal.MockClient{},
@@ -372,6 +375,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, },
MobileDependency{}, MobileDependency{},
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
@@ -410,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
@@ -587,12 +591,13 @@ func TestEngine_Sync(t *testing.T) {
} }
return nil return nil
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103", WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
@@ -751,12 +756,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName, WgIfaceName: wgIfaceName,
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
@@ -952,12 +958,13 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName, WgIfaceName: wgIfaceName,
WgAddr: wgAddr, WgAddr: wgAddr,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx engine.ctx = ctx
@@ -1179,6 +1186,7 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
config: &EngineConfig{ config: &EngineConfig{
IFaceBlackList: testCase.inputBlacklistInterface, IFaceBlackList: testCase.inputBlacklistInterface,
NATExternalIPs: testCase.inputMapList, NATExternalIPs: testCase.inputMapList,
MTU: iface.DefaultMTU,
}, },
} }
parsedList := engine.parseNATExternalIPMappings() parsedList := engine.parseNATExternalIPMappings()
@@ -1479,9 +1487,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgAddr: resp.PeerConfig.Address, WgAddr: resp.PeerConfig.Address,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: wgPort, WgPort: wgPort,
MTU: iface.DefaultMTU,
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx e.ctx = ctx
return e, err return e, err
@@ -1513,15 +1522,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper() t.Helper()
config := &types.Config{ config := &config.Config{
Stuns: []*types.Host{}, Stuns: []*config.Host{},
TURNConfig: &types.TURNConfig{}, TURNConfig: &config.TURNConfig{},
Relay: &types.Relay{ Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"}, Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour}, CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222", Secret: "222222222222222222",
}, },
Signal: &types.Host{ Signal: &config.Host{
Proto: "http", Proto: "http",
URI: "localhost:10000", URI: "localhost:10000",
}, },
@@ -1564,13 +1573,14 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err

View File

@@ -1,6 +1,8 @@
package internal package internal
import ( import (
"net/netip"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -13,7 +15,7 @@ type MobileDependency struct {
TunAdapter device.TunAdapter TunAdapter device.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover IFaceDiscover stdnet.ExternalIFaceDiscover
NetworkChangeListener listener.NetworkChangeListener NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string HostDNSAddresses []netip.AddrPort
DnsReadyListener dns.ReadyListener DnsReadyListener dns.ReadyListener
// iOS only // iOS only

View File

@@ -24,8 +24,8 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker" "github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -200,19 +200,11 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1) conn.wg.Add(1)
go func() { go func() {
defer conn.wg.Done() defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx) conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx) conn.semaphore.Done(conn.ctx)
conn.dumpState.SendOffer() conn.guard.Start(conn.ctx, conn.onGuardEvent)
if err := conn.handshaker.sendOffer(); err != nil {
conn.Log.Errorf("failed to send initial offer: %v", err)
}
conn.wg.Add(1)
go func() {
conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.wg.Done()
}()
}() }()
conn.opened = true conn.opened = true
return nil return nil
@@ -274,10 +266,10 @@ func (conn *Conn) Close(signalToRemote bool) {
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
conn.dumpState.RemoteAnswer() conn.dumpState.RemoteAnswer()
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) conn.handshaker.OnRemoteAnswer(answer)
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -296,10 +288,10 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler conn.onDisconnected = handler
} }
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer() conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) conn.handshaker.OnRemoteOffer(offer)
} }
// WgConfig returns the WireGuard config // WgConfig returns the WireGuard config
@@ -548,7 +540,6 @@ func (conn *Conn) onRelayDisconnected() {
} }
func (conn *Conn) onGuardEvent() { func (conn *Conn) onGuardEvent() {
conn.Log.Debugf("send offer to peer")
conn.dumpState.SendOffer() conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil { if err := conn.handshaker.SendOffer(); err != nil {
conn.Log.Errorf("failed to send offer: %v", err) conn.Log.Errorf("failed to send offer: %v", err)
@@ -672,7 +663,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
} }
}() }()
if conn.statusICE.Get() == worker.StatusDisconnected { if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false return false
} }

View File

@@ -1,9 +1,9 @@
package peer package peer
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync"
"testing" "testing"
"time" "time"
@@ -79,31 +79,30 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return return
} }
wg := sync.WaitGroup{} onNewOffeChan := make(chan struct{})
wg.Add(2)
go func() {
<-conn.handshaker.remoteOffersCh
wg.Done()
}()
go func() { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
for { onNewOffeChan <- struct{}{}
accepted := conn.OnRemoteOffer(OfferAnswer{ })
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()
return
}
}
}()
wg.Wait() conn.OnRemoteOffer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
@@ -119,31 +118,29 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return return
} }
wg := sync.WaitGroup{} onNewOffeChan := make(chan struct{})
wg.Add(2)
go func() {
<-conn.handshaker.remoteAnswerCh
wg.Done()
}()
go func() { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
for { onNewOffeChan <- struct{}{}
accepted := conn.OnRemoteAnswer(OfferAnswer{ })
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()
return
}
}
}()
wg.Wait() conn.OnRemoteAnswer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
} }
func TestConn_presharedKey(t *testing.T) { func TestConn_presharedKey(t *testing.T) {

View File

@@ -19,7 +19,6 @@ type isConnectedFunc func() bool
// - Relayed connection disconnected // - Relayed connection disconnected
// - ICE candidate changes // - ICE candidate changes
type Guard struct { type Guard struct {
Reconnect chan struct{}
log *log.Entry log *log.Entry
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay isConnectedFunc
timeout time.Duration timeout time.Duration
@@ -30,7 +29,6 @@ type Guard struct {
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
Reconnect: make(chan struct{}, 1),
log: log, log: log,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
timeout: timeout, timeout: timeout,
@@ -41,6 +39,7 @@ func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Durati
} }
func (g *Guard) Start(ctx context.Context, eventCallback func()) { func (g *Guard) Start(ctx context.Context, eventCallback func()) {
g.log.Infof("starting guard for reconnection with MaxInterval: %s", g.timeout)
g.reconnectLoopWithRetry(ctx, eventCallback) g.reconnectLoopWithRetry(ctx, eventCallback)
} }
@@ -61,17 +60,14 @@ func (g *Guard) SetICEConnDisconnected() {
// reconnectLoopWithRetry periodically check the connection status. // reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported // Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan) defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.prepareExponentTicker(ctx) ticker := g.initialTicker(ctx)
defer ticker.Stop() defer ticker.Stop()
tickerChannel := ticker.C tickerChannel := ticker.C
g.log.Infof("start reconnect loop...")
for { for {
select { select {
case t := <-tickerChannel: case t := <-tickerChannel:
@@ -85,7 +81,6 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
if !g.isConnectedOnAllWay() { if !g.isConnectedOnAllWay() {
callback() callback()
} }
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker") g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
@@ -111,6 +106,20 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
} }
} }
// initialTicker give chance to the peer to establish the initial connection.
func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 3 * time.Second,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
@@ -126,13 +135,3 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
return ticker return ticker
} }
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
}

View File

@@ -39,6 +39,15 @@ type OfferAnswer struct {
// relay server address // relay server address
RelaySrvAddress string RelaySrvAddress string
// SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID
}
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
} }
type Handshaker struct { type Handshaker struct {
@@ -74,21 +83,25 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
h.log.Info("wait for remote offer confirmation") select {
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx) case remoteOfferAnswer := <-h.remoteOffersCh:
if err != nil { // received confirmation from the remote peer -> ready to proceed
var connectionClosedError *ConnectionClosedError if err := h.sendAnswer(); err != nil {
if errors.As(err, &connectionClosedError) { h.log.Errorf("failed to send remote offer confirmation: %s", err)
h.log.Info("exit from handshaker") continue
return
} }
h.log.Errorf("failed to received remote offer confirmation: %s", err) for _, listener := range h.onNewOfferListeners {
continue listener(&remoteOfferAnswer)
} }
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
h.log.Infof("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort) case remoteOfferAnswer := <-h.remoteAnswerCh:
for _, listener := range h.onNewOfferListeners { h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
go listener(remoteOfferAnswer) for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer)
}
case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers")
return
} }
} }
} }
@@ -101,43 +114,27 @@ func (h *Handshaker) SendOffer() error {
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool { func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) {
select { select {
case h.remoteOffersCh <- offer: case h.remoteOffersCh <- offer:
return true return
default: default:
h.log.Warnf("OnRemoteOffer skipping message because is not ready") h.log.Warnf("skipping remote offer message because receiver not ready")
// connection might not be ready yet to receive so we ignore the message // connection might not be ready yet to receive so we ignore the message
return false return
} }
} }
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool { func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) {
select { select {
case h.remoteAnswerCh <- answer: case h.remoteAnswerCh <- answer:
return true return
default: default:
// connection might not be ready yet to receive so we ignore the message // connection might not be ready yet to receive so we ignore the message
h.log.Debugf("OnRemoteAnswer skipping message because is not ready") h.log.Warnf("skipping remote answer message because receiver not ready")
return false return
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
if err := h.sendAnswer(); err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
} }
} }
@@ -147,43 +144,34 @@ func (h *Handshaker) sendOffer() error {
return ErrSignalIsNotReady return ErrSignalIsNotReady
} }
iceUFrag, icePwd := h.ice.GetLocalUserCredentials() offer := h.buildOfferAnswer()
offer := OfferAnswer{ h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
offer.RelaySrvAddress = addr
}
return h.signaler.SignalOffer(offer, h.config.Key) return h.signaler.SignalOffer(offer, h.config.Key)
} }
func (h *Handshaker) sendAnswer() error { func (h *Handshaker) sendAnswer() error {
h.log.Infof("sending answer") answer := h.buildOfferAnswer()
uFrag, pwd := h.ice.GetLocalUserCredentials() h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
return h.signaler.SignalAnswer(answer, h.config.Key)
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{ answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd}, IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort, WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(), Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr, RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
} }
addr, err := h.relay.RelayInstanceAddress()
if err == nil { if addr, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr answer.RelaySrvAddress = addr
} }
err = h.signaler.SignalAnswer(answer, h.config.Key) return answer
if err != nil {
return err
}
return nil
} }

View File

@@ -1,6 +1,7 @@
package ice package ice
import ( import (
"sync"
"time" "time"
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
@@ -23,7 +24,20 @@ const (
iceRelayAcceptanceMinWaitDefault = 2 * time.Second iceRelayAcceptanceMinWaitDefault = 2 * time.Second
) )
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { type ThreadSafeAgent struct {
*ice.Agent
once sync.Once
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
err = a.Agent.Close()
})
return err
}
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive() iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout() iceDisconnectedTimeout := iceDisconnectedTimeout()
iceFailedTimeout := iceFailedTimeout() iceFailedTimeout := iceFailedTimeout()
@@ -61,7 +75,12 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
} }
return ice.NewAgent(agentConfig) agent, err := ice.NewAgent(agentConfig)
if err != nil {
return nil, err
}
return &ThreadSafeAgent{Agent: agent}, nil
} }
func GenerateICECredentials() (string, string, error) { func GenerateICECredentials() (string, string, error) {

View File

@@ -0,0 +1,47 @@
package peer
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
const sessionIDSize = 5
type ICESessionID string
// NewICESessionID generates a new session ID for distinguishing sessions
func NewICESessionID() (ICESessionID, error) {
b := make([]byte, sessionIDSize)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
return ICESessionID(hex.EncodeToString(b)), nil
}
func ICESessionIDFromBytes(b []byte) (ICESessionID, error) {
if len(b) != sessionIDSize {
return "", fmt.Errorf("invalid session ID length: %d", len(b))
}
return ICESessionID(hex.EncodeToString(b)), nil
}
// Bytes returns the raw bytes of the session ID for protobuf serialization
func (id ICESessionID) Bytes() ([]byte, error) {
if len(id) == 0 {
return nil, fmt.Errorf("ICE session ID is empty")
}
b, err := hex.DecodeString(string(id))
if err != nil {
return nil, fmt.Errorf("invalid ICE session ID encoding: %w", err)
}
if len(b) != sessionIDSize {
return nil, fmt.Errorf("invalid ICE session ID length: expected %d bytes, got %d", sessionIDSize, len(b))
}
return b, nil
}
func (id ICESessionID) String() string {
return string(id)
}

View File

@@ -2,6 +2,7 @@ package peer
import ( import (
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -45,6 +46,10 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer // SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
msg, err := signal.MarshalCredential( msg, err := signal.MarshalCredential(
s.wgPrivateKey, s.wgPrivateKey,
offerAnswer.WgListenPort, offerAnswer.WgListenPort,
@@ -56,13 +61,13 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
bodyType, bodyType,
offerAnswer.RosenpassPubKey, offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr, offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress) offerAnswer.RelaySrvAddress,
sessionIDBytes)
if err != nil { if err != nil {
return err return err
} }
err = s.signal.Send(msg) if err = s.signal.Send(msg); err != nil {
if err != nil {
return err return err
} }

View File

@@ -140,7 +140,7 @@ type RosenpassState struct {
// whether it's enabled, and the last error message encountered during probing. // whether it's enabled, and the last error message encountered during probing.
type NSGroupState struct { type NSGroupState struct {
ID string ID string
Servers []string Servers []netip.AddrPort
Domains []string Domains []string
Enabled bool Enabled bool
Error error Error error

View File

@@ -42,8 +42,18 @@ type WorkerICE struct {
statusRecorder *Status statusRecorder *Status
hasRelayOnLocally bool hasRelayOnLocally bool
agent *ice.Agent agent *icemaker.ThreadSafeAgent
muxAgent sync.Mutex agentDialerCancel context.CancelFunc
agentConnecting bool // while it is true, drop all incoming offers
lastSuccess time.Time // with this avoid the too frequent ICE agent recreation
// remoteSessionID represents the peer's session identifier from the latest remote offer.
remoteSessionID ICESessionID
// sessionID is used to track the current session ID of the ICE agent
// increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID
muxAgent sync.Mutex
StunTurn []*stun.URI StunTurn []*stun.URI
@@ -57,6 +67,11 @@ type WorkerICE struct {
} }
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
sessionID, err := NewICESessionID()
if err != nil {
return nil, err
}
w := &WorkerICE{ w := &WorkerICE{
ctx: ctx, ctx: ctx,
log: log, log: log,
@@ -67,6 +82,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally, hasRelayOnLocally: hasRelayOnLocally,
lastKnownState: ice.ConnectionStateDisconnected, lastKnownState: ice.ConnectionStateDisconnected,
sessionID: sessionID,
} }
localUfrag, localPwd, err := icemaker.GenerateICECredentials() localUfrag, localPwd, err := icemaker.GenerateICECredentials()
@@ -79,15 +95,36 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
} }
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE") w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock() w.muxAgent.Lock()
if w.agent != nil { if w.agentConnecting {
w.log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
if w.agent != nil {
// backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
// todo consider to switch to Relay connection while establishing a new ICE connection
}
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
preferredCandidateTypes = icemaker.CandidateTypesP2P() preferredCandidateTypes = icemaker.CandidateTypesP2P()
@@ -96,36 +133,124 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
} }
w.log.Debugf("recreate ICE agent") w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx) dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil { if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err) w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
w.sentExtraSrflx = false
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
w.muxAgent.Unlock() w.muxAgent.Unlock()
w.log.Debugf("gather candidates") go w.connect(dialerCtx, agent, remoteOfferAnswer)
err = w.agent.GatherCandidates() }
if err != nil {
w.log.Debugf("failed to gather candidates: %s", err) // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) InProgress() bool {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.agentConnecting
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
if err := agent.OnCandidate(w.onICECandidate); err != nil {
return nil, err
}
if err := agent.OnConnectionStateChange(w.onConnectionStateChange(agent, dialerCancel)); err != nil {
return nil, err
}
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
return nil, err
}
if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) SessionID() ICESessionID {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.sessionID
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("gather candidates")
if err := agent.GatherCandidates(); err != nil {
w.log.Warnf("failed to gather candidates: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turn agent dial") w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer) remoteConn, err := w.turnAgentDial(ctx, agent, remoteOfferAnswer)
if err != nil { if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err) w.log.Debugf("failed to dial the remote peer: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
w.log.Debugf("agent dial succeeded") w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair() pair, err := agent.GetSelectedCandidatePair()
if err != nil { if err != nil {
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
@@ -152,114 +277,39 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
} }
w.log.Debugf("on ICE conn is ready to use") w.log.Debugf("on ICE conn is ready to use")
go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. w.log.Infof("connection succeeded with offer session: %s", remoteOfferAnswer.SessionIDString())
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock() w.agentConnecting = false
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String()) w.lastSuccess = time.Now()
if w.agent == nil { if remoteOfferAnswer.SessionID != nil {
w.log.Warnf("ICE Agent is not initialized yet") w.remoteSessionID = *remoteOfferAnswer.SessionID
return
} }
w.muxAgent.Unlock()
if candidateViaRoutes(candidate, haRoutes) { // todo: the potential problem is a race between the onConnectionStateChange
return w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
err := w.agent.AddRemoteCandidate(candidate)
if err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
} }
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
w.sentExtraSrflx = false
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
err = agent.OnCandidate(w.onICECandidate)
if err != nil {
return nil, err
}
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
return nil, err
}
err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
if err != nil {
return nil, err
}
err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
if err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
})
if err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel() cancel()
if w.agent == nil { if err := agent.Close(); err != nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.agent = nil
w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent {
w.agent = nil
w.agentConnecting = false
}
w.muxAgent.Unlock()
} }
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -331,6 +381,32 @@ func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidat
w.config.Key) w.config.Key)
} }
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agent, dialerCancel)
default:
return
}
}
}
func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) {
if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true return true
@@ -338,12 +414,12 @@ func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool
return false return false
} }
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key isControlling := w.config.LocalKey > w.config.Key
if isControlling { if isControlling {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else { } else {
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} }
} }

Some files were not shown because too many files have changed in this diff Show More