Compare commits

..

8 Commits

Author SHA1 Message Date
bcmmbaga
72513d7522 Skip network map calculation when client serial matches current
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-01-27 23:03:43 +03:00
Zoltan Papp
a1f1bf1f19 Merge branch 'main' into feat/network-map-serial 2025-12-18 15:59:53 +01:00
Zoltan Papp
b5dec3df39 Track network serial in engine 2025-12-18 15:27:49 +01:00
Hakan Sariman
20f5f00635 [client] Add unit tests for engine synchronization and Info flag copying
- Introduced tests for the Engine's handleSync method to verify behavior when SkipNetworkMapUpdate is true and when NetworkMap is nil.
- Added a test for the Info struct to ensure correct copying of flag values from one instance to another, while preserving unrelated fields.
2025-10-17 10:03:07 +03:00
Hakan Sariman
fc141cf3a3 [client] Refactor lastNetworkMapSerial handling in GrpcClient
- Removed atomic operations for lastNetworkMapSerial and replaced them with mutex-based methods for thread-safe access.
2025-09-29 18:49:23 +07:00
Hakan Sariman
d0c65fa08e [client] Add skipNetworkMapUpdate field to SyncResponse for conditional updates 2025-09-29 18:28:14 +07:00
Hakan Sariman
f241bfa339 Refactor flag setting in Info struct to use CopyFlagsFrom method 2025-09-29 15:38:35 +07:00
Hakan Sariman
4b2cd97d5f [client] Enhance SyncRequest with NetworkMap serial tracking
- Added `networkMapSerial` field to `SyncRequest` for tracking the last known network map serial number.
- Updated `GrpcClient` to store and utilize the last network map serial during sync operations, optimizing synchronization processes.
- Improved handling of system info updates to ensure accurate metadata is sent with sync requests.
2025-09-25 19:28:35 +07:00
336 changed files with 3559 additions and 26956 deletions

View File

@@ -1,15 +1,15 @@
FROM golang:1.25-bookworm
FROM golang:1.23-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\
gettext-base=0.21-12 \
iptables=1.8.9-2 \
libgl1-mesa-dev=22.3.6-1+deb12u1 \
xorg-dev=1:7.7+23 \
libayatana-appindicator3-dev=0.5.92-1 \
gettext-base=0.21-4 \
iptables=1.8.7-1 \
libgl1-mesa-dev=20.3.5-1 \
xorg-dev=1:7.7+22 \
libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& go install -v golang.org/x/tools/gopls@latest
&& go install -v golang.org/x/tools/gopls@v0.18.1
WORKDIR /app

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
@@ -39,7 +39,7 @@ jobs:
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...

View File

@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.25-alpine \
golang:1.24-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -259,7 +259,7 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
-timeout 10m ./relay/... ./shared/relay/...
test_signal:
name: "Signal / Unit"

View File

@@ -52,10 +52,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@v4
with:
version: latest
skip-cache: true
skip-save-cache: true
cache-invalidation-interval: 0
args: --timeout=12m
args: --timeout=12m --out-format colored-line-number

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.0"
SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -19,100 +19,6 @@ concurrency:
cancel-in-progress: true
jobs:
release_freebsd_port:
name: "FreeBSD Port / Build & Test"
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff
run: |
if ls netbird-*.diff 1> /dev/null 2>&1; then
echo "diff_exists=true" >> $GITHUB_OUTPUT
else
echo "diff_exists=false" >> $GITHUB_OUTPUT
echo "No diff file generated (port may already be up to date)"
fi
- name: Extract version
if: steps.check_diff.outputs.diff_exists == 'true'
id: version
run: |
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
prepare: |
# Install required packages
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
# Clone ports tree (shallow, only what we need)
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
cd /usr/ports
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin
# Find the diff file
echo "Finding diff file..."
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
echo "Found: $DIFF_FILE"
if [[ -z "$DIFF_FILE" ]]; then
echo "ERROR: Could not find diff file"
find ~ -name "*.diff" -type f 2>/dev/null || true
exit 1
fi
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
cd /usr/ports
patch -p1 -V none < "$DIFF_FILE"
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
./netbird-*-issue.txt
./netbird-*.diff
retention-days: 30
release:
runs-on: ubuntu-latest-m
env:

View File

@@ -243,7 +243,6 @@ jobs:
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose logs
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db

View File

@@ -14,9 +14,6 @@ jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
env:
GOOS: js
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -27,14 +24,16 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-save-cache: true
cache-invalidation-interval: 0
working-directory: ./client
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:

1
.gitignore vendored
View File

@@ -31,4 +31,3 @@ infrastructure_files/setup-*.env
.DS_Store
vendor/
/netbird
client/netbird-electron/

View File

@@ -1,124 +1,139 @@
version: "2"
linters:
default: none
enable:
- bodyclose
- dupword
- durationcheck
- errcheck
- forbidigo
- gocritic
- gosec
- govet
- ineffassign
- mirror
- misspell
- nilerr
- nilnil
- predeclared
- revive
- sqlclosecheck
- staticcheck
- unused
- wastedassign
settings:
errcheck:
check-type-assertions: false
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
gosec:
includes:
- G101
- G103
- G104
- G106
- G108
- G109
- G110
- G111
- G201
- G202
- G203
- G301
- G302
- G303
- G304
- G305
- G306
- G307
- G403
- G502
- G503
- G504
- G601
- G602
govet:
enable:
- nilness
enable-all: false
revive:
rules:
- name: exported
arguments:
- checkPrivateReceivers
- sayRepetitiveInsteadOfStutters
severity: warning
disabled: false
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
run:
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 6m
# This file contains only configs which differ from defaults.
# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
linters-settings:
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: false
gosec:
includes:
- G101 # Look for hard coded credentials
#- G102 # Bind to all interfaces
- G103 # Audit the use of unsafe block
- G104 # Audit errors not checked
- G106 # Audit the use of ssh.InsecureIgnoreHostKey
#- G107 # Url provided to HTTP request as taint input
- G108 # Profiling endpoint automatically exposed on /debug/pprof
- G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
- G110 # Potential DoS vulnerability via decompression bomb
- G111 # Potential directory traversal
#- G112 # Potential slowloris attack
- G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
#- G114 # Use of net/http serve function that has no support for setting timeouts
- G201 # SQL query construction using format string
- G202 # SQL query construction using string concatenation
- G203 # Use of unescaped data in HTML templates
#- G204 # Audit use of command execution
- G301 # Poor file permissions used when creating a directory
- G302 # Poor file permissions used with chmod
- G303 # Creating tempfile using a predictable path
- G304 # File path provided as taint input
- G305 # File traversal when extracting zip/tar archive
- G306 # Poor file permissions used when writing to a new file
- G307 # Poor file permissions used when creating a file with os.Create
#- G401 # Detect the usage of DES, RC4, MD5 or SHA1
#- G402 # Look for bad TLS connection settings
- G403 # Ensure minimum RSA key length of 2048 bits
#- G404 # Insecure random number source (rand)
#- G501 # Import blocklist: crypto/md5
- G502 # Import blocklist: crypto/des
- G503 # Import blocklist: crypto/rc4
- G504 # Import blocklist: net/http/cgi
#- G505 # Import blocklist: crypto/sha1
- G601 # Implicit memory aliasing of items from a range statement
- G602 # Slice access out of bounds
gocritic:
disabled-checks:
- commentFormatting
- captLocal
- deprecatedComment
govet:
# Enable all analyzers.
# Default: false
enable-all: false
enable:
- nilness
revive:
rules:
- linters:
- forbidigo
path: management/cmd/root\.go
- linters:
- forbidigo
path: signal/cmd/root\.go
- linters:
- unused
path: sharedsock/filter\.go
- linters:
- unused
path: client/firewall/iptables/rule\.go
- linters:
- gosec
- mirror
path: test\.go
- linters:
- nilnil
path: mock\.go
- linters:
- staticcheck
text: grpc.DialContext is deprecated
- linters:
- staticcheck
text: grpc.WithBlock is deprecated
- linters:
- staticcheck
text: "QF1001"
- linters:
- staticcheck
text: "QF1008"
- linters:
- staticcheck
text: "QF1012"
paths:
- third_party$
- builtin$
- examples$
- name: exported
severity: warning
disabled: false
arguments:
- "checkPrivateReceivers"
- "sayRepetitiveInsteadOfStutters"
tenv:
# The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
# Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
# Default: false
all: true
linters:
disable-all: true
enable:
## enabled by default
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- gosimple # specializes in simplifying a code
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- ineffassign # detects when assignments to existing variables are not used
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- typecheck # like the front-end of a Go compiler, parses and type-checks Go code
- unused # checks for unused constants, variables, functions and types
## disable by default but the have interesting results so lets add them
- bodyclose # checks whether HTTP response body is closed successfully
- dupword # dupword checks for duplicate words in the source code
- durationcheck # durationcheck checks for two durations multiplied together
- forbidigo # forbidigo forbids identifiers
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gosec # inspects source code for security problems
- mirror # mirror reports wrong mirror patterns of bytes/strings usage
- misspell # misspess finds commonly misspelled English words in comments
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 5
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$
exclude-rules:
# allow fmt
- path: management/cmd/root\.go
linters: forbidigo
- path: signal/cmd/root\.go
linters: forbidigo
- path: sharedsock/filter\.go
linters:
- unused
- path: client/firewall/iptables/rule\.go
linters:
- unused
- path: test\.go
linters:
- mirror
- gosec
- path: mock\.go
linters:
- nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -713,10 +713,8 @@ checksum:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh

View File

@@ -38,11 +38,6 @@
</strong>
<br>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
</strong>
<br>
<br>
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird terraform provider
</a>
@@ -90,7 +85,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
- **Public domain** name pointing to the VM.
**Software requirements:**
@@ -103,7 +98,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Steps**
- Download and run the installation script:
```bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
@@ -118,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.

View File

@@ -59,6 +59,7 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
@@ -67,16 +68,18 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateFile string
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: platformFiles.ConfigurationFilePath(),
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -84,20 +87,15 @@ func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAd
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
stateFile: platformFiles.StateFilePath(),
}
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: cfgFile,
ConfigPath: c.cfgFile,
})
if err != nil {
return err
@@ -124,22 +122,16 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: cfgFile,
ConfigPath: c.cfgFile,
})
if err != nil {
return err
@@ -157,8 +149,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -1,257 +0,0 @@
//go:build android
package android
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
// Android-specific config filename (different from desktop default.json)
defaultConfigFilename = "netbird.cfg"
// Subdirectory for non-default profiles (must match Java Preferences.java)
profilesSubdir = "profiles"
// Android uses a single user context per app (non-empty username required by ServiceManager)
androidUsername = "android"
)
// Profile represents a profile for gomobile
type Profile struct {
Name string
IsActive bool
}
// ProfileArray wraps profiles for gomobile compatibility
type ProfileArray struct {
items []*Profile
}
// Length returns the number of profiles
func (p *ProfileArray) Length() int {
return len(p.items)
}
// Get returns the profile at index i
func (p *ProfileArray) Get(i int) *Profile {
if i < 0 || i >= len(p.items) {
return nil
}
return p.items[i]
}
/*
/data/data/io.netbird.client/files/ ← configDir parameter
├── netbird.cfg ← Default profile config
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
└── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
// It wraps the internal profilemanager to provide Android-specific behavior
type ProfileManager struct {
configDir string
serviceMgr *profilemanager.ServiceManager
}
// NewProfileManager creates a new profile manager for Android
func NewProfileManager(configDir string) *ProfileManager {
// Set the default config path for Android (stored in root configDir, not profiles/)
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
// Set global paths for Android
profilemanager.DefaultConfigPathDir = configDir
profilemanager.DefaultConfigPath = defaultConfigPath
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
// Create ServiceManager with profiles/ subdirectory
// This avoids modifying the global ConfigDirOverride for profile listing
profilesDir := filepath.Join(configDir, profilesSubdir)
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
return &ProfileManager{
configDir: configDir,
serviceMgr: serviceMgr,
}
}
// ListProfiles returns all available profiles
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
// Convert internal profiles to Android Profile type
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
Name: p.Name,
IsActive: p.IsActive,
})
}
return &ProfileArray{items: profiles}, nil
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
config, err := profilemanager.ReadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to read profile config: %w", err)
}
// Clear authentication by removing private key and SSH key
config.PrivateKey = ""
config.SSHKey = ""
// Save config using internal profilemanager
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}

View File

@@ -136,7 +136,6 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
//nolint
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}

View File

@@ -81,7 +81,6 @@ var loginCmd = &cobra.Command{
func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
@@ -207,7 +206,6 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -1,4 +1,5 @@
//go:build pprof
// +build pprof
package cmd

View File

@@ -85,9 +85,6 @@ var (
// Execute executes the root command.
func Execute() error {
if isUpdateBinary() {
return updateCmd.Execute()
}
return rootCmd.Execute()
}
@@ -390,7 +387,6 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -1,176 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
bundlePubKeysRootPrivKeyFile string
bundlePubKeysPubKeyFiles []string
bundlePubKeysFile string
createArtifactKeyRootPrivKeyFile string
createArtifactKeyPrivKeyFile string
createArtifactKeyPubKeyFile string
createArtifactKeyExpiration time.Duration
)
var createArtifactKeyCmd = &cobra.Command{
Use: "create-artifact-key",
Short: "Create a new artifact signing key",
Long: `Generate a new artifact signing key pair signed by the root private key.
The artifact key will be used to sign software artifacts/updates.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if createArtifactKeyExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
return fmt.Errorf("failed to create artifact key: %w", err)
}
return nil
},
}
var bundlePubKeysCmd = &cobra.Command{
Use: "bundle-pub-keys",
Short: "Bundle multiple artifact public keys into a signed package",
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
RunE: func(cmd *cobra.Command, args []string) error {
if len(bundlePubKeysPubKeyFiles) == 0 {
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
}
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
return fmt.Errorf("failed to bundle public keys: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createArtifactKeyCmd)
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(fmt.Errorf("mark expiration as required: %w", err))
}
rootCmd.AddCommand(bundlePubKeysCmd)
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
}
}
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
cmd.Println("Creating new artifact signing key...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
if err != nil {
return fmt.Errorf("generate artifact key: %w", err)
}
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
}
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
}
signatureFile := artifactPubKeyFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Artifact key created successfully.\n")
cmd.Printf("%s\n", artifactKey.String())
return nil
}
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
cmd.Println("📦 Bundling public keys into signed package...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
for _, pubFile := range artifactPubKeyFiles {
pubPem, err := os.ReadFile(pubFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
pk, err := reposign.ParseArtifactPubKey(pubPem)
if err != nil {
return fmt.Errorf("failed to parse artifact key: %w", err)
}
publicKeys = append(publicKeys, pk)
}
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
if err != nil {
return fmt.Errorf("bundle artifact keys: %w", err)
}
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
}
signatureFile := bundlePubKeysFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
return nil
}

View File

@@ -1,276 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
)
var (
signArtifactPrivKeyFile string
signArtifactArtifactFile string
verifyArtifactPubKeyFile string
verifyArtifactFile string
verifyArtifactSignatureFile string
verifyArtifactKeyPubKeyFile string
verifyArtifactKeyRootPubKeyFile string
verifyArtifactKeySignatureFile string
verifyArtifactKeyRevocationFile string
)
var signArtifactCmd = &cobra.Command{
Use: "sign-artifact",
Short: "Sign an artifact using an artifact private key",
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
return fmt.Errorf("failed to sign artifact: %w", err)
}
return nil
},
}
var verifyArtifactCmd = &cobra.Command{
Use: "verify-artifact",
Short: "Verify an artifact signature using an artifact public key",
Long: `Verify a software artifact signature using the artifact's public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
return fmt.Errorf("failed to verify artifact: %w", err)
}
return nil
},
}
var verifyArtifactKeyCmd = &cobra.Command{
Use: "verify-artifact-key",
Short: "Verify an artifact public key was signed by a root key",
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
This validates the chain of trust from the root key to the artifact key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
return fmt.Errorf("failed to verify artifact key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(signArtifactCmd)
rootCmd.AddCommand(verifyArtifactCmd)
rootCmd.AddCommand(verifyArtifactKeyCmd)
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
// artifact-file is required, but artifact-key-file can come from env var
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
panic(fmt.Errorf("mark root-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
}
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
cmd.Println("🖋️ Signing artifact...")
// Load private key from env var or file
var privKeyPEM []byte
var err error
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
// Use key from environment variable
privKeyPEM = []byte(envKey)
} else if privKeyFile != "" {
// Fall back to file
privKeyPEM, err = os.ReadFile(privKeyFile)
if err != nil {
return fmt.Errorf("read private key file: %w", err)
}
} else {
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
}
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact private key: %w", err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
signature, err := reposign.SignData(privateKey, artifactData)
if err != nil {
return fmt.Errorf("sign artifact: %w", err)
}
sigFile := artifactFile + ".sig"
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
}
cmd.Printf("✅ Artifact signed successfully.\n")
cmd.Printf("Signature file: %s\n", sigFile)
return nil
}
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
cmd.Println("🔍 Verifying artifact...")
// Read artifact public key
pubKeyPEM, err := os.ReadFile(pubKeyFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact public key: %w", err)
}
// Read artifact data
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate artifact
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
return fmt.Errorf("artifact verification failed: %w", err)
}
cmd.Println("✅ Artifact signature is valid")
cmd.Printf("Artifact: %s\n", artifactFile)
cmd.Printf("Signed by key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
return nil
}
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
cmd.Println("🔍 Verifying artifact key...")
// Read artifact key data
artifactKeyData, err := os.ReadFile(artifactKeyFile)
if err != nil {
return fmt.Errorf("read artifact key file: %w", err)
}
// Read root public key(s)
rootKeyData, err := os.ReadFile(rootKeyFile)
if err != nil {
return fmt.Errorf("read root key file: %w", err)
}
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
if err != nil {
return fmt.Errorf("failed to parse root public key(s): %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Read optional revocation list
var revocationList *reposign.RevocationList
if revocationFile != "" {
revData, err := os.ReadFile(revocationFile)
if err != nil {
return fmt.Errorf("read revocation file: %w", err)
}
revocationList, err = reposign.ParseRevocationList(revData)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
}
// Validate artifact key(s)
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
if err != nil {
return fmt.Errorf("artifact key verification failed: %w", err)
}
cmd.Println("✅ Artifact key(s) verified successfully")
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
for i, key := range validKeys {
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
if !key.Metadata.ExpiresAt.IsZero() {
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
} else {
cmd.Printf(" Expires: Never\n")
}
}
return nil
}
// parseRootPublicKeys parses a root public key from PEM data
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
key, err := reposign.ParseRootPublicKey(data)
if err != nil {
return nil, err
}
return []reposign.PublicKey{key}, nil
}

View File

@@ -1,21 +0,0 @@
package main
import (
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "signer",
Short: "A CLI tool for managing cryptographic keys and artifacts",
Long: `signer is a command-line tool that helps you manage
root keys, artifact keys, and revocation lists securely.`,
}
func main() {
if err := rootCmd.Execute(); err != nil {
rootCmd.Println(err)
os.Exit(1)
}
}

View File

@@ -1,220 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
)
var (
keyID string
revocationListFile string
privateRootKeyFile string
publicRootKeyFile string
signatureFile string
expirationDuration time.Duration
)
var createRevocationListCmd = &cobra.Command{
Use: "create-revocation-list",
Short: "Create a new revocation list signed by the private root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
},
}
var extendRevocationListCmd = &cobra.Command{
Use: "extend-revocation-list",
Short: "Extend an existing revocation list with a given key ID",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
},
}
var verifyRevocationListCmd = &cobra.Command{
Use: "verify-revocation-list",
Short: "Verify a revocation list signature using the public root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
},
}
func init() {
rootCmd.AddCommand(createRevocationListCmd)
rootCmd.AddCommand(extendRevocationListCmd)
rootCmd.AddCommand(verifyRevocationListCmd)
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
panic(err)
}
}
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
if err != nil {
return fmt.Errorf("failed to create revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list created successfully")
return nil
}
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
rl, err := reposign.ParseRevocationList(rlBytes)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
kid, err := reposign.ParseKeyID(keyID)
if err != nil {
return fmt.Errorf("invalid key ID: %w", err)
}
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
if err != nil {
return fmt.Errorf("failed to extend revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list extended successfully")
return nil
}
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
// Read revocation list file
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
// Read signature file
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("failed to read signature file: %w", err)
}
// Read public root key file
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read public root key file: %w", err)
}
// Parse public root key
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public root key: %w", err)
}
// Parse signature
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate revocation list
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
if err != nil {
return fmt.Errorf("failed to validate revocation list: %w", err)
}
// Display results
cmd.Println("✅ Revocation list signature is valid")
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
if len(rl.Revoked) > 0 {
cmd.Println("\nRevoked Keys:")
for keyID, revokedTime := range rl.Revoked {
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
}
}
return nil
}
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
return fmt.Errorf("failed to write revocation list file: %w", err)
}
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
return fmt.Errorf("failed to write signature file: %w", err)
}
return nil
}

View File

@@ -1,74 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
privKeyFile string
pubKeyFile string
rootExpiration time.Duration
)
var createRootKeyCmd = &cobra.Command{
Use: "create-root-key",
Short: "Create a new root key pair",
Long: `Create a new root key pair and specify an expiration time for it.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Validate expiration
if rootExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
// Run main logic
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createRootKeyCmd)
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(err)
}
}
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
if err != nil {
return fmt.Errorf("generate root key: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
}
// Write public key
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
}
cmd.Printf("%s\n\n", rk.String())
cmd.Printf("✅ Root key pair generated successfully.\n")
return nil
}

View File

@@ -634,11 +634,7 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
return err
}
if err := validateDestinationPort(remoteAddr); err != nil {
return fmt.Errorf("invalid remote address: %w", err)
}
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -656,11 +652,7 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return err
}
if err := validateDestinationPort(localAddr); err != nil {
return fmt.Errorf("invalid local address: %w", err)
}
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -671,35 +663,6 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return nil
}
// validateDestinationPort checks that the destination address has a valid port.
// Port 0 is only valid for bind addresses (where the OS picks an available port),
// not for destination addresses where we need to connect.
func validateDestinationPort(addr string) error {
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
return nil
}
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("parse address %s: %w", addr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port %s: %w", portStr, err)
}
if port == 0 {
return fmt.Errorf("port 0 is not valid for destination address")
}
if port < 0 || port > 65535 {
return fmt.Errorf("port %d out of range (1-65535)", port)
}
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {

View File

@@ -124,7 +124,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -89,6 +89,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, nil
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
@@ -124,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r, false)
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
@@ -216,7 +216,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -1,13 +0,0 @@
//go:build !windows && !darwin
package cmd
import (
"github.com/spf13/cobra"
)
var updateCmd *cobra.Command
func isUpdateBinary() bool {
return false
}

View File

@@ -1,75 +0,0 @@
//go:build windows || darwin
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)
var (
updateCmd = &cobra.Command{
Use: "update",
Short: "Update the NetBird client application",
RunE: updateFunc,
}
tempDirFlag string
installerFile string
serviceDirFlag string
dryRunFlag bool
)
func init() {
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
}
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
func isUpdateBinary() bool {
// Remove extension for cross-platform compatibility
execPath, err := os.Executable()
if err != nil {
return false
}
baseName := filepath.Base(execPath)
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
return name == installer.UpdaterBinaryNameWithoutExtension()
}
func updateFunc(cmd *cobra.Command, args []string) error {
if err := setupLogToFile(tempDirFlag); err != nil {
return err
}
log.Infof("updater started: %s", serviceDirFlag)
updater := installer.NewWithDir(tempDirFlag)
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
log.Errorf("failed to update application: %v", err)
return err
}
return nil
}
func setupLogToFile(dir string) error {
logFile := filepath.Join(dir, installer.LogFile)
if _, err := os.Stat(logFile); err == nil {
if err := os.Remove(logFile); err != nil {
log.Errorf("failed to remove existing log file: %v\n", err)
}
}
return util.InitLog(logLevel, util.LogConsole, logFile)
}

View File

@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available

View File

@@ -386,8 +386,11 @@ func (m *aclManager) updateState() {
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
if ip.IsUnspecified() {
matchByIP = false
}
if matchByIP {
if ipsetName != "" {

View File

@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
t.Logf(" [%d] %s", i, rule)
}
var denyRuleIndex, acceptRuleIndex = -1, -1
var denyRuleIndex, acceptRuleIndex int = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)

View File

@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
var acceptRuleIndex, denyRuleIndex = -1, -1
var acceptRuleIndex, denyRuleIndex int = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
@@ -208,13 +208,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
switch lookup.SetName {
case "accept-http":
if lookup.SetName == "accept-http" {
hasAcceptHTTPSet = true
case "deny-http":
} else if lookup.SetName == "deny-http" {
hasDenyHTTPSet = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
@@ -224,10 +222,9 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
switch verdict.Kind {
case expr.VerdictAccept:
if verdict.Kind == expr.VerdictAccept {
action = "ACCEPT"
case expr.VerdictDrop:
} else if verdict.Kind == expr.VerdictDrop {
action = "DROP"
}
}
@@ -389,97 +386,6 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
const octet2Count = 25
const octet3Count = 255
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
for i := 1; i < octet2Count; i++ {
for j := 1; j < octet3Count; j++ {
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddRouteFiltering(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")

View File

@@ -48,12 +48,10 @@ const (
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
)
@@ -515,35 +513,16 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
}
elements := convertPrefixesToSet(prefixes)
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
var subEnd int
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd = min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}

View File

@@ -29,7 +29,7 @@ import (
)
const (
layerTypeAll = 255
layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
@@ -262,7 +262,10 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix := iface.Address().Network
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering(
@@ -436,7 +439,19 @@ func (m *Manager) AddPeerFiltering(
r.sPort = sPort
r.dPort = dPort
r.protoLayer = protoToLayer(proto, r.ipLayer)
switch proto {
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet
@@ -481,17 +496,16 @@ func (m *Manager) addRouteFiltering(
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort,
dstPort: dPort,
action: action,
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
@@ -781,7 +795,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum = pseudoSum
var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
@@ -931,7 +945,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked {
pnum := getProtocolFromPacket(d)
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@@ -996,22 +1010,20 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return false
}
protoLayer := d.decoded[1]
proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: proto,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
@@ -1040,33 +1052,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true
}
func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
return layers.LayerTypeTCP
case firewall.ProtocolUDP:
return layers.LayerTypeUDP
case firewall.ProtocolICMP:
if ipLayer == layers.LayerTypeIPv6 {
return layers.LayerTypeICMPv6
}
return layers.LayerTypeICMPv4
case firewall.ProtocolALL:
return layerTypeAll
}
return 0
}
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return nftypes.TCP
return firewall.ProtocolTCP, nftypes.TCP
case layers.LayerTypeUDP:
return nftypes.UDP
return firewall.ProtocolUDP, nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return nftypes.ICMP
return firewall.ProtocolICMP, nftypes.ICMP
default:
return nftypes.ProtocolUnknown
return firewall.ProtocolALL, nftypes.ProtocolUnknown
}
}
@@ -1238,30 +1233,19 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept
}
}
return nil, false
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false
}
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
@@ -1280,8 +1264,21 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
break
}
}
if !sourceMatched {
return false
}
return sourceMatched
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched

View File

@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}
}

View File

@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)
})
}
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
}
})
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -2,7 +2,6 @@ package forwarder
import (
"fmt"
"sync/atomic"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -17,7 +16,7 @@ type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu atomic.Uint32
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -29,7 +28,7 @@ func (e *endpoint) IsAttached() bool {
}
func (e *endpoint) MTU() uint32 {
return e.mtu.Load()
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -83,22 +82,6 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
func (e *endpoint) Close() {
// Endpoint cleanup - nothing to do as device is managed externally
}
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
// Link address is not used for this endpoint type
}
func (e *endpoint) SetMTU(mtu uint32) {
e.mtu.Store(mtu)
}
func (e *endpoint) SetOnCloseAction(func()) {
// No action needed on close
}
type epID stack.TransportEndpointID
func (i epID) String() string {

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
@@ -36,16 +35,14 @@ type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -63,8 +60,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
endpoint.mtu.Store(uint32(mtu))
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("create NIC: %v", err)
@@ -106,16 +103,15 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
pingSemaphore: make(chan struct{}, 3),
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
}
receiveWindow := defaultReceiveWindow
@@ -133,8 +129,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
f.checkICMPCapability()
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
@@ -204,24 +198,3 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
DstPort: dstPort,
}
}
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
}
f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
}

View File

@@ -2,11 +2,8 @@ package forwarder
import (
"context"
"fmt"
"net"
"net/netip"
"os/exec"
"runtime"
"time"
"github.com/google/uuid"
@@ -17,95 +14,30 @@ import (
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
// For Echo Requests, send and wait for response
if icmpHdr.Type() == header.ICMPv4Echo {
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
if !f.hasRawICMPAccess {
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
return false
}
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
if err != nil {
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
}
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
select {
case f.pingSemaphore <- struct{}{}:
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
rxBytes := pkt.Size()
go func() {
defer func() { <-f.pingSemaphore }()
if f.hasRawICMPAccess {
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
} else {
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
}
}()
default:
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return true
}
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
// The caller is responsible for closing the returned connection.
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
ctx, cancel := context.WithTimeout(f.ctx, timeout)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
return nil, fmt.Errorf("create ICMP socket: %w", err)
}
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
if _, err = conn.WriteTo(payload, dst); err != nil {
if closeErr := conn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
}
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
return conn, nil
}
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
sendTime := time.Now()
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
if err != nil {
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
return
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
@@ -113,22 +45,38 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
}
}()
txBytes := f.handleEchoResponse(conn, id)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
rxBytes := pkt.Size()
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
}
response := make([]byte, f.endpoint.mtu.Load())
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
@@ -137,7 +85,31 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
return 0
}
return f.injectICMPReply(id, response[:n])
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
return 0
}
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
}
// sendICMPEvent stores flow events for ICMP packets
@@ -180,95 +152,3 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
f.flowLogger.StoreEvent(fields)
}
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
// This is used as a fallback when raw socket access is not available.
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
}
}
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
// Returns the size of the injected packet.
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyICMPHdr := header.ICMPv4(replyICMP)
replyICMPHdr.SetType(header.ICMPv4EchoReply)
replyICMPHdr.SetChecksum(0)
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
return f.injectICMPReply(id, replyICMP)
}
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
// Returns the total size of the injected packet, or 0 if injection failed.
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
return 0
}
return len(fullPacket)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
@@ -132,10 +131,10 @@ func (f *udpForwarder) cleanup() {
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return false
return
}
id := r.ID()
@@ -145,7 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return true
return
}
flowID := uuid.New()
@@ -163,7 +162,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err != nil {
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return false
return
}
// Create wait queue for blocking syscalls
@@ -174,10 +173,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return false
return
}
inConn := gonet.NewUDPConn(&wq, ep)
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
@@ -200,7 +199,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
if err := outConn.Close(); err != nil {
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return true
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
@@ -209,7 +208,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
return true
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
@@ -350,7 +348,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {

View File

@@ -130,7 +130,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}

View File

@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
_, _ = mapManager.localIPs[ip.String()]
}
})
@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_ = mapManager.localIPs[ip.String()]
_, _ = mapManager.localIPs[ip.String()]
}
})
}

View File

@@ -168,15 +168,6 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
}
}
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
}
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {

View File

@@ -234,10 +234,9 @@ func TestInboundPortDNATNegative(t *testing.T) {
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
switch tc.protocol {
case layers.IPProtocolTCP:
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
case layers.IPProtocolUDP:
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})

View File

@@ -34,7 +34,7 @@ type RouteRule struct {
sources []netip.Prefix
dstSet firewall.Set
destinations []netip.Prefix
protoLayer gopacket.LayerType
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action

View File

@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
protoLayer := d.decoded[1]
proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id)
if id == nil {

View File

@@ -27,23 +27,8 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
}
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
buf := bufs[0]
size, ep, err := conn.ReadFromUDPAddrPort(buf)
if err != nil {
return 0, err
}
sizes[0] = size
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
eps[0] = stdEp
return 1, nil
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:

View File

@@ -1,7 +1,9 @@
//go:build ios
// +build ios
package device
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
@@ -43,31 +45,10 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
}
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
// This typically means the Swift code couldn't find the utun control socket.
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
var tunDevice tun.Device
var err error
// Validate the tunnel file descriptor.
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
// A value of 0 means the Swift code couldn't find the utun control socket
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
// tvOS SDK headers). This is a hard error - there's no viable fallback
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
if t.tunFd == 0 {
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
return nil, ErrInvalidTunnelFD
}
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
var dupTunFd int
dupTunFd, err = unix.Dup(t.tunFd)
dupTunFd, err := unix.Dup(t.tunFd)
if err != nil {
log.Errorf("Unable to dup tun fd: %v", err)
return nil, err
@@ -79,7 +60,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
_ = unix.Close(dupTunFd)
return nil, err
}
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err)
_ = unix.Close(dupTunFd)

View File

@@ -3,19 +3,12 @@
package wgproxy
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
const (
envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY"
)
type KernelFactory struct {
wgPort int
mtu uint16
@@ -29,12 +22,6 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
mtu: mtu,
}
if isEBPFDisabled() {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy)
return f
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
@@ -60,16 +47,3 @@ func (w *KernelFactory) Free() error {
}
return w.ebpfProxy.Free()
}
func isEBPFDisabled() bool {
val := os.Getenv(envDisableEBPFWGProxy)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err)
return false
}
return disabled
}

View File

@@ -24,14 +24,10 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -43,13 +39,11 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -58,15 +52,13 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
}
@@ -170,33 +162,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
var path string
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
// On mobile, use the provided state file path directly
if !fileExists(mobileDependency.StateFilePath) {
if err := createFile(mobileDependency.StateFilePath); err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDependency.StateFilePath
} else {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -308,7 +273,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -318,15 +283,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)

View File

@@ -27,7 +27,6 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -57,7 +56,6 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -111,9 +109,6 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -332,10 +327,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -363,10 +354,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {
log.Errorf("failed to add updater logs: %v", err)
}
return nil
}
@@ -535,18 +522,6 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {
@@ -655,29 +630,6 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
func (g *BundleGenerator) addUpdateLogs() error {
inst := installer.New()
logFiles := inst.LogFiles()
if len(logFiles) == 0 {
return nil
}
log.Infof("adding updater logs")
for _, logFile := range logFiles {
data, err := os.ReadFile(logFile)
if err != nil {
log.Warnf("failed to read update log file %s: %v", logFile, err)
continue
}
baseName := filepath.Base(logFile)
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
}
}
return nil
}
func (g *BundleGenerator) addCorruptedStateFiles() error {
sm := profilemanager.NewServiceManager("")
pattern := sm.GetStatePath()

View File

@@ -507,13 +507,15 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
if p.Base == expr.PayloadBaseNetworkHeader {
switch p.Offset {
case 12:
switch p.Len {
case 4, 2:
if p.Len == 4 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
case 16:
switch p.Len {
case 4, 2:
if p.Len == 4 {
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
} else if p.Len == 2 {
return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
}
}

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.NonAuthoritative {
if zone.SkipPTRProcess {
continue
}
for _, record := range zone.Records {

View File

@@ -3,15 +3,11 @@ package dns
import (
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
)
const (
@@ -47,23 +43,7 @@ type HandlerChain struct {
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
requestID string
shouldContinue bool
response *dns.Msg
meta map[string]string
}
// RequestID returns the request ID for tracing
func (w *ResponseWriterChain) RequestID() string {
return w.requestID
}
// SetMeta sets a metadata key-value pair for logging
func (w *ResponseWriterChain) SetMeta(key, value string) {
if w.meta == nil {
w.meta = make(map[string]string)
}
w.meta[key] = value
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
@@ -72,7 +52,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
w.shouldContinue = true
return nil
}
w.response = m
return w.ResponseWriter.WriteMsg(m)
}
@@ -122,8 +101,6 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
c.logHandlers()
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
@@ -163,109 +140,68 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
c.logHandlers()
break
}
}
}
// logHandlers logs the current handler chain state. Caller must hold the lock.
func (c *HandlerChain) logHandlers() {
if !log.IsLevelEnabled(log.TraceLevel) {
return
}
var b strings.Builder
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
for _, h := range c.handlers {
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
" priority=" + strconv.Itoa(h.Priority) + "\n")
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
qname := strings.ToLower(r.Question[0].Name)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order
for _, entry := range handlers {
if !c.isHandlerMatch(qname, entry) {
continue
}
matched := c.isHandlerMatch(qname, entry)
handlerName := entry.OrigPattern
if s, ok := entry.Handler.(interface{ String() string }); ok {
handlerName = s.String()
}
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
requestID: requestID,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
if entry.Priority != PriorityMgmtCache {
logger.Tracef("handler requested continue for domain=%s", qname)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
continue
}
entry.Handler.ServeDNS(chainWriter, r)
c.logResponse(logger, chainWriter, qname, startTime)
return
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return
}
}
// No handler matched or all handlers passed
logger.Tracef("no handler found for domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
log.Tracef("no handler found for domain=%s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
}
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
if cw.response == nil {
return
}
var meta string
for k, v := range cw.meta {
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":

View File

@@ -1,52 +1,30 @@
package local
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
const externalResolutionTimeout = 4 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
ctx context.Context
cancel context.CancelFunc
}
func NewResolver() *Resolver {
ctx, cancel := context.WithCancel(context.Background())
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
zones: make(map[domain.Domain]bool),
ctx: ctx,
cancel: cancel,
}
}
@@ -59,18 +37,7 @@ func (d *Resolver) String() string {
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {
if d.cancel != nil {
d.cancel()
}
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
}
func (d *Resolver) Stop() {}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
@@ -81,85 +48,35 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
if len(r.Question) == 0 {
logger.Debug("received local resolver request with no question")
log.Debugf("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
result := d.lookupRecords(logger, question)
replyMessage.Authoritative = !result.hasExternalData
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
d.continueToNext(logger, w, r)
return
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
}
if err := w.WriteMsg(replyMessage); err != nil {
logger.Warnf("failed to write the local resolver response: %v", err)
}
}
// determineRcode returns the appropriate DNS response code.
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
// even if CNAME records are included in the answer.
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
// Use the rcode from lookup - this properly handles CNAME chains where
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
if result.rcode != 0 {
return result.rcode
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
return dns.RcodeSuccess
}
return dns.RcodeNameError
}
// findZone finds the matching zone for a query name using reverse suffix lookup.
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
qname = strings.ToLower(dns.Fqdn(qname))
for {
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
return nonAuth, true
}
// Move to parent domain
idx := strings.Index(qname, ".")
if idx == -1 || idx == len(qname)-1 {
return false, false
}
qname = qname[idx+1:]
}
}
// shouldFallthrough checks if the query should fallthrough to the next handler.
// Returns true if the queried name belongs to a non-authoritative zone.
func (d *Resolver) shouldFallthrough(qname string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
nonAuth, found := d.findZone(qname)
return found && nonAuth
}
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Warnf("failed to write continue signal: %v", err)
log.Warnf("failed to write the local resolver response: %v", err)
}
}
@@ -172,27 +89,8 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
return exists
}
// isInManagedZone checks if the given name falls within any of our managed zones.
// This is used to avoid unnecessary external resolution for CNAME targets that
// are within zones we manage - if we don't have a record for it, it doesn't exist.
// Caller must NOT hold the lock.
func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, found := d.findZone(name)
return found
}
// lookupResult contains the result of a DNS lookup operation.
type lookupResult struct {
records []dns.RR
rcode int
hasExternalData bool
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
records, found := d.records[question]
@@ -200,14 +98,10 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
cnameQuestion := dns.Question{
Name: question.Name,
Qtype: dns.TypeCNAME,
Qclass: question.Qclass,
}
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
}
return lookupResult{rcode: dns.RcodeNameError}
return nil
}
recordsCopy := slices.Clone(records)
@@ -225,178 +119,20 @@ func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) looku
d.mu.Unlock()
}
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
return recordsCopy
}
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
const maxDepth = 8
var chain []dns.RR
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 {
break
}
chain = append(chain, cnameRecords...)
cname, ok := cnameRecords[0].(*dns.CNAME)
if !ok {
break
}
targetName := strings.ToLower(cname.Target)
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
// keep following chain
if targetResult.rcode == -1 {
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
continue
}
return d.buildChainResult(chain, targetResult)
}
if len(chain) > 0 {
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
}
return lookupResult{rcode: dns.RcodeSuccess}
}
// buildChainResult combines CNAME chain records with the target resolution result.
// Per RFC 6604, the final rcode is propagated through the chain.
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
records := chain
if len(target.records) > 0 {
records = append(records, target.records...)
}
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
return lookupResult{
records: records,
rcode: dns.RcodeServerFailure,
hasExternalData: true,
}
}
return lookupResult{
records: records,
rcode: target.rcode,
hasExternalData: target.hasExternalData,
}
}
// resolveCNAMETarget attempts to resolve a CNAME target name.
// Returns rcode=-1 to signal "keep following the chain".
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// another CNAME, keep following
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
return lookupResult{rcode: -1}
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName)) {
return lookupResult{rcode: dns.RcodeSuccess}
}
// in our zone but doesn't exist (NXDOMAIN)
if d.isInManagedZone(targetName) {
return lookupResult{rcode: dns.RcodeNameError}
}
return d.resolveExternal(logger, targetName, targetType)
}
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
d.mu.RLock()
defer d.mu.RUnlock()
return d.records[q]
}
func (d *Resolver) hasRecord(q dns.Question) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, ok := d.records[q]
return ok
}
// resolveExternal resolves a domain name using the system resolver.
// This is used to resolve CNAME targets that point outside our local zone,
// which is required for musl libc compatibility (musl expects complete answers).
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return lookupResult{rcode: dns.RcodeNotImplemented}
}
resolver := d.resolver
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
defer cancel()
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
if result.Err != nil {
d.logDNSError(logger, name, qtype, result.Err)
return lookupResult{rcode: result.Rcode, hasExternalData: true}
}
return lookupResult{
records: resutil.IPsToRRs(name, result.IPs, 60),
rcode: dns.RcodeSuccess,
hasExternalData: true,
}
}
// logDNSError logs DNS resolution errors for debugging.
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
qtypeName := dns.TypeToString[qtype]
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
return
}
if dnsErr.IsNotFound {
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
return
}
if dnsErr.Server != "" {
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
} else {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
}
}
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
for _, zone := range customZones {
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
d.zones[zoneDomain] = zone.NonAuthoritative
for _, rec := range zone.Records {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
}
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
}
}
}

View File

@@ -1,14 +1,8 @@
package local
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -18,18 +12,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
// mockResolver implements resolver for testing
type mockResolver struct {
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
}
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
if m.lookupFunc != nil {
return m.lookupFunc(ctx, network, host)
}
return nil, nil
}
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
@@ -124,11 +106,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
resolver := NewResolver()
zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
// Apply first update
resolver.Update(zone1)
resolver.Update(update1)
// Verify first update
resolver.mu.RLock()
@@ -140,7 +122,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(zone2)
resolver.Update(update2)
// Verify second update
resolver.mu.RLock()
@@ -169,10 +151,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records
resolver.Update(zones)
resolver.Update(update)
// Create question that matches both records
question := dns.Question{
@@ -213,10 +195,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records
resolver.Update(zones)
resolver.Update(update)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
@@ -282,7 +264,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
}
// Update resolver with the records
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
testCases := []struct {
name string
@@ -397,7 +379,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
}
// Update resolver with both records
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
testCases := []struct {
name string
@@ -494,20 +476,6 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
// with 0 records instead of NXDOMAIN
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
resolver := NewResolver()
// Mock external resolver for CNAME target resolution
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "target.example.com." {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
if network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
recordA := nbdns.SimpleRecord{
Name: "example.netbird.cloud.",
@@ -525,7 +493,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
RData: "target.example.com.",
}
resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
testCases := []struct {
name string
@@ -614,808 +582,3 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
})
}
}
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("simple internal CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "target.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "192.168.1.1", a.A.String())
})
t.Run("multi-hop CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3)
})
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok)
})
}
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
t.Run("chain at max depth resolves", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 7 CNAMEs (under max of 8)
for i := 1; i <= 7; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("hop%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("hop%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 8)
})
t.Run("chain exceeding max depth stops", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 10 CNAMEs (exceeds max of 8)
for i := 1; i <= 10; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("deep%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("deep%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
// Should NOT have the final A record (chain too deep)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
},
}})
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
}
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "93.184.216.34", a.A.String())
})
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
aaaa, ok := resp.Answer[1].(*dns.AAAA)
require.True(t, ok)
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
})
t.Run("concurrent external resolution", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
var wg sync.WaitGroup
results := make([]*dns.Msg, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
results[idx] = resp
}(i)
}
wg.Wait()
for i, resp := range results {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
}
})
}
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("Update sets zones correctly", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}},
{Domain: "test.local."},
})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("other.example.com."))
assert.True(t, resolver.isInManagedZone("sub.test.local."))
assert.False(t, resolver.isInManagedZone("external.com."))
})
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
})
t.Run("Update clears zones", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
resolver.Update(nil)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
}
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
})
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.other.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
})
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
resolver := NewResolver()
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." {
if network == "ip6" {
// No AAAA records
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
if network == "ip4" {
// But A records exist - domain exists
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
// Table-driven test for all external resolution outcomes
externalCases := []struct {
name string
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
expectedRcode int
expectedAnswer int
}{
{
name: "external NXDOMAIN (both A and AAAA not found)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeNameError,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (temporary error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTemporary: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (timeout)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTimeout: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (generic error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, fmt.Errorf("connection refused")
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external success with IPs",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeSuccess,
expectedAnswer: 2, // CNAME + A
},
}
for _, tc := range externalCases {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
if tc.expectedAnswer > 0 {
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "first answer should be CNAME")
}
})
}
}
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
// trigger fallthrough (Zero bit set) when no records match
func TestLocalResolver_Fallthrough(t *testing.T) {
resolver := NewResolver()
record := nbdns.SimpleRecord{
Name: "existing.custom.zone.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}
testCases := []struct {
name string
zones []nbdns.CustomZone
queryName string
expectFallthrough bool
expectRecord bool
}{
{
name: "Authoritative zone returns NXDOMAIN without fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: false,
expectRecord: false,
},
{
name: "Non-authoritative zone triggers fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: true,
expectRecord: false,
},
{
name: "Record found in non-authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
{
name: "Record found in authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resolver.Update(tc.zones)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response")
if tc.expectFallthrough {
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
} else {
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
}
if tc.expectRecord {
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
}
})
}
}
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
t.Run("direct record lookup is authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.True(t, resp.Authoritative)
})
t.Run("external resolution is not authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2)
assert.False(t, resp.Authoritative)
})
}
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Len(t, resp.Answer, 0)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
resolver.Stop()
resolver.Stop()
})
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})
lookupCtxCanceled := make(chan struct{})
resolver.resolver = &mockResolver{
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
close(lookupStarted)
<-ctx.Done()
close(lookupCtxCanceled)
return nil, ctx.Err()
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
done := make(chan struct{})
go func() {
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
close(done)
}()
<-lookupStarted
resolver.Stop()
select {
case <-lookupCtxCanceled:
case <-time.After(time.Second):
t.Fatal("external lookup context was not canceled")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("ServeDNS did not return after Stop")
}
})
}
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "EXAMPLE.COM.",
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
NonAuthoritative: true,
}})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG)
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
}
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
func BenchmarkFindZone_BestCase(b *testing.B) {
resolver := NewResolver()
// Single zone that matches immediately
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
NonAuthoritative: true,
}})
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough("example.com.")
}
}
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
func BenchmarkFindZone_WorstCase(b *testing.B) {
resolver := NewResolver()
// 100 zones that won't match
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
NonAuthoritative: true,
})
}
resolver.Update(zones)
// Query with many labels that won't match any zone
qname := "a.b.c.d.e.f.g.h.external.com."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
func BenchmarkFindZone_TypicalCase(b *testing.B) {
resolver := NewResolver()
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
resolver.Update([]nbdns.CustomZone{
{Domain: "netbird.cloud.", NonAuthoritative: false},
{Domain: "custom.local.", NonAuthoritative: true},
})
// Query for subdomain of user zone
qname := "myhost.custom.local."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver := NewResolver()
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
})
}
resolver.Update(zones)
// Query that matches zone50
qname := "host.zone50.internal."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.isInManagedZone(qname)
}
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
@@ -27,11 +26,6 @@ type Resolver struct {
mutex sync.RWMutex
}
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
@@ -105,9 +99,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := lookupIPWithExtraTimeout(ctx, d)
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
return err
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
}
var aRecords, aaaaRecords []dns.RR
@@ -165,36 +159,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return nil
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {

View File

@@ -1,197 +0,0 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
return ""
}
}
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
RequestID() string
SetMeta(key, value string)
}
// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
if cw, ok := w.(chainedWriter); ok {
if id := cw.RequestID(); id != "" {
return id
}
}
return GenerateRequestID()
}
// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
if cw, ok := w.(chainedWriter); ok {
cw.SetMeta(key, value)
}
}
// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
IPs []netip.Addr
Rcode int
Err error // Original error for caller's logging needs
}
// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
ips, err := r.LookupNetIP(ctx, network, host)
if err != nil {
return LookupResult{
Rcode: getRcodeForError(ctx, r, host, qtype, err),
Err: err,
}
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
return LookupResult{
IPs: ips,
Rcode: dns.RcodeSuccess,
}
}
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
return dns.RcodeServerFailure
}
if dnsErr.IsNotFound {
return getRcodeForNotFound(ctx, r, host, qtype)
}
return dns.RcodeServerFailure
}
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
return dns.RcodeNameError
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
return dns.RcodeSuccess
}
// Alternative query succeeded - domain exists but has no records of this type
return dns.RcodeSuccess
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {
return "[]"
}
parts := make([]string, 0, len(answers))
for _, rr := range answers {
switch r := rr.(type) {
case *dns.A:
parts = append(parts, r.A.String())
case *dns.AAAA:
parts = append(parts, r.AAAA.String())
case *dns.CNAME:
parts = append(parts, "CNAME:"+r.Target)
case *dns.PTR:
parts = append(parts, "PTR:"+r.Ptr)
default:
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
}
}
return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -80,7 +80,6 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
@@ -208,7 +207,6 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
}
// register with root zone, handler chain takes care of the routing
@@ -485,7 +483,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
}
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -498,7 +496,8 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
s.localResolver.Update(localZones)
// register local records
s.localResolver.Update(localRecords)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -587,29 +586,8 @@ func (s *DefaultServer) applyHostConfig() {
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
// Fall through to apply config anyway (fail-safe approach)
} else if s.currentConfigHash == hash {
log.Debugf("not applying host config as there are no changes")
return
}
log.Debugf("applying host config as there are changes")
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
return
}
// Only update hash if it was computed successfully and config was applied
if err == nil {
s.currentConfigHash = hash
}
s.registerFallback(config)
@@ -658,9 +636,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
var muxUpdates []handlerWrapper
var zones []nbdns.CustomZone
var localRecords []nbdns.SimpleRecord
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -674,20 +652,17 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal,
})
// zone records contain the fqdn, so we can just flatten them
var localRecords []nbdns.SimpleRecord
for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
customZone.Records = localRecords
zones = append(zones, customZone)
}
return muxUpdates, zones, nil
return muxUpdates, localRecords, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {

View File

@@ -128,7 +128,7 @@ func TestUpdateDNSServer(t *testing.T) {
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalZones []nbdns.CustomZone
initLocalRecords []nbdns.SimpleRecord
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
@@ -181,7 +181,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
@@ -222,7 +222,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
@@ -230,7 +230,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -252,7 +252,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -274,7 +274,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initLocalRecords: []nbdns.SimpleRecord{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -300,7 +300,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -316,7 +316,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -385,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.localResolver.Update(testCase.initLocalRecords)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -510,7 +510,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@@ -1601,10 +1602,7 @@ func TestExtraDomains(t *testing.T) {
"other.example.com.",
"duplicate.example.com.",
},
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
// the domain remains in the config (ref count goes from 2 to 1), so the host
// config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 3,
applyHostConfigCall: 4,
},
{
name: "Config update with new domains after registration",
@@ -1659,10 +1657,7 @@ func TestExtraDomains(t *testing.T) {
expectedMatchOnly: []string{
"extra.example.com.",
},
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
// it's removed from extraDomains but still remains in the config (from customZones),
// so the host config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 2,
applyHostConfigCall: 3,
},
{
name: "Register domain that is part of nameserver group",

View File

@@ -2,6 +2,7 @@ package dns
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
@@ -20,7 +21,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
@@ -113,7 +113,10 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
@@ -199,18 +202,11 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
rm.MsgHdr.Zero = false
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
@@ -418,6 +414,16 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -18,7 +18,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
)
@@ -190,22 +189,29 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return nil
}
question := query.Question[0]
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
var network string
switch question.Qtype {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
@@ -215,35 +221,33 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err)
return nil
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
}
@@ -261,33 +265,19 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
log.Errorf("failed to write DNS response: %v", err)
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
resp := f.handleDNSQuery(w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
log.Errorf("failed to write DNS response: %v", err)
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -325,64 +315,140 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
}
}
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
result resutil.LookupResult,
err error,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
resp.Rcode = result.Rcode
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// NotFound: set NXDOMAIN / appropriate code via helper.
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.cache.set(domain, question.Qtype, nil)
return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
// Cached negative result - re-verify NXDOMAIN vs NODATA
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
return
}
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
log.Warnf(errResolveFailed, domain, err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
}
}

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -318,7 +317,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
@@ -466,7 +465,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
@@ -528,7 +527,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
@@ -605,7 +604,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -675,7 +674,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -685,7 +684,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -715,7 +714,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -729,7 +728,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -784,7 +783,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -905,7 +904,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
@@ -938,7 +937,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")

View File

@@ -42,13 +42,14 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -72,7 +73,6 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -201,9 +201,6 @@ type Engine struct {
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
updateManager *updatemanager.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -224,7 +221,17 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -240,12 +247,28 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
}
@@ -285,10 +308,6 @@ func (e *Engine) Stop() error {
e.srWatcher.Close()
}
if e.updateManager != nil {
e.updateManager.Stop()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
@@ -522,13 +541,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -737,41 +749,6 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -781,10 +758,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -824,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
nm := update.GetNetworkMap()
if nm == nil {
if nm == nil || update.SkipNetworkMapUpdate {
return nil
}
@@ -990,7 +963,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1121,15 +1094,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
@@ -1138,34 +1102,32 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return err
}
} else {
err := e.removePeers(remotePeers)
err := e.removePeers(networkMap.GetRemotePeers())
if err != nil {
return err
}
err = e.modifyPeers(remotePeers)
err = e.modifyPeers(networkMap.GetRemotePeers())
if err != nil {
return err
}
err = e.addNewPeers(remotePeers)
err = e.addNewPeers(networkMap.GetRemotePeers())
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
if err := e.updateSSHClientConfig(remotePeers); err != nil {
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
e.networkSerial = serial
@@ -1251,16 +1213,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
ForwarderPort: forwarderPort,
}
protoZones := protoDNSConfig.GetCustomZones()
// Treat single zone as authoritative for backward compatibility with old servers
// that only send the peer FQDN zone without setting field 4.
singleZoneCompat := len(protoZones) == 1
for _, zone := range protoZones {
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
SkipPTRProcess: zone.GetSkipPTRProcess(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{

View File

@@ -11,18 +11,15 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
UpdateSSHAuth(config *sshauth.Config)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -356,38 +353,3 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
return sshServer.GetStatus()
}
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil {
return
}
if e.sshServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
// Update SSH server with new authorization configuration
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.sshServer.UpdateSSHAuth(authConfig)
}

View File

@@ -0,0 +1,79 @@
package internal
import (
"context"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun199",
WgAddr: "100.70.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
// Precondition
if engine.networkSerial != 0 {
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
}
resp := &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
SkipNetworkMapUpdate: true,
}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
if engine.networkSerial != 0 {
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
}
}
// Ensures handleSync exits early when NetworkMap is nil
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun198",
WgAddr: "100.70.0.2/24",
WgPrivateKey: key,
WgPort: 33101,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
}

View File

@@ -253,7 +253,6 @@ func TestEngine_SSH(t *testing.T) {
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -415,13 +414,21 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -624,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -640,7 +647,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -805,7 +812,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1007,7 +1014,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1533,7 +1540,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
}
@@ -1631,7 +1638,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -1,4 +1,5 @@
//go:build !windows
// +build !windows
package internal

View File

@@ -110,6 +110,7 @@ func wakeUpListen(ctx context.Context) {
}
if newHash == initialHash {
log.Tracef("no wakeup detected")
continue
}

View File

@@ -148,15 +148,13 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.semaphore.Add(engineCtx)
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.opened {
conn.semaphore.Done()
conn.semaphore.Done(engineCtx)
return nil
}
@@ -167,7 +165,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
conn.semaphore.Done()
return err
}
conn.workerICE = workerICE
@@ -203,7 +200,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done()
conn.semaphore.Done(conn.ctx)
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
@@ -669,17 +666,10 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig
initiator bool
// mu protects cancelFunc
// mu protects updateWireGuardPeer and cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
@@ -86,9 +86,11 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}

View File

@@ -3,11 +3,9 @@ package profilemanager
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
@@ -167,26 +165,19 @@ func getConfigDir() (string, error) {
if ConfigDirOverride != "" {
return ConfigDirOverride, nil
}
base, err := baseConfigDir()
configDir, err := os.UserConfigDir()
if err != nil {
return "", err
}
configDir := filepath.Join(base, "netbird")
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
}
func baseConfigDir() (string, error) {
if runtime.GOOS == "darwin" {
if u, err := user.Current(); err == nil && u.HomeDir != "" {
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
configDir = filepath.Join(configDir, "netbird")
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
}
}
return os.UserConfigDir()
return configDir, nil
}
func getConfigDirForUser(username string) (string, error) {
@@ -685,7 +676,7 @@ func update(input ConfigInput) (*Config, error) {
return config, nil
}
// GetConfig read config file and return with Config and if it was created. Errors out if it does not exist
// GetConfig read config file and return with Config. Errors out if it does not exist
func GetConfig(configPath string) (*Config, error) {
return readConfig(configPath, false)
}
@@ -821,85 +812,3 @@ func readConfig(configPath string, createIfMissing bool) (*Config, error) {
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(context.Background(), path, config)
}
// DirectWriteOutConfig writes config directly without atomic temp file operations.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectWriteOutConfig(path string, config *Config) error {
return util.DirectWriteJson(context.Background(), path, config)
}
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {
return nil, err
}
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
if err := util.EnforcePermission(input.ConfigPath); err != nil {
log.Errorf("failed to enforce permission on config file: %v", err)
}
return directUpdate(input)
}
func directUpdate(input ConfigInput) (*Config, error) {
config := &Config{}
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
return nil, err
}
updated, err := config.apply(input)
if err != nil {
return nil, err
}
if updated {
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err
}
}
return config, nil
}
// ConfigToJSON serializes a Config struct to a JSON string.
// This is useful for exporting config to alternative storage mechanisms
// (e.g., UserDefaults on tvOS where file writes are blocked).
func ConfigToJSON(config *Config) (string, error) {
bs, err := json.MarshalIndent(config, "", " ")
if err != nil {
return "", err
}
return string(bs), nil
}
// ConfigFromJSON deserializes a JSON string to a Config struct.
// This is useful for restoring config from alternative storage mechanisms.
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
func ConfigFromJSON(jsonStr string) (*Config, error) {
config := &Config{}
err := json.Unmarshal([]byte(jsonStr), config)
if err != nil {
return nil, err
}
// Apply defaults to ensure required fields are initialized.
// This mirrors what readConfig does after loading from file.
if _, err := config.apply(ConfigInput{}); err != nil {
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
}
return config, nil
}

View File

@@ -76,7 +76,6 @@ func (a *ActiveProfileState) FilePath() (string, error) {
}
type ServiceManager struct {
profilesDir string // If set, overrides ConfigDirOverride for profile operations
}
func NewServiceManager(defaultConfigPath string) *ServiceManager {
@@ -86,17 +85,6 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
return &ServiceManager{}
}
// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
// This allows setting the profiles directory without modifying the global ConfigDirOverride
func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
if defaultConfigPath != "" {
DefaultConfigPath = defaultConfigPath
}
return &ServiceManager{
profilesDir: profilesDir,
}
}
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
@@ -126,6 +114,14 @@ func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
log.Warnf("failed to set permissions for default profile: %v", err)
}
if err := s.SetActiveProfileState(&ActiveProfileState{
Name: "default",
Username: "",
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return false, fmt.Errorf("failed to set active profile state: %w", err)
}
return true, nil
}
@@ -244,7 +240,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
configDir, err := getConfigDirForUser(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -274,7 +270,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
configDir, err := getConfigDirForUser(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -306,7 +302,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
}
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
configDir, err := s.getConfigDir(username)
configDir, err := getConfigDirForUser(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
@@ -365,7 +361,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
configDir, err := s.getConfigDir(activeProf.Username)
configDir, err := getConfigDirForUser(activeProf.Username)
if err != nil {
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
return defaultStatePath
@@ -373,12 +369,3 @@ func (s *ServiceManager) GetStatePath() string {
return filepath.Join(configDir, activeProf.Name+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
func (s *ServiceManager) getConfigDir(username string) (string, error) {
if s.profilesDir != "" {
return s.profilesDir, nil
}
return getConfigDirForUser(username)
}

View File

@@ -19,7 +19,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
@@ -220,14 +219,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{
"request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 {
return
}
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
@@ -281,10 +280,15 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
resutil.SetMeta(w, "peer", peerKey)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id
if err := d.writeMsg(w, reply, logger); err != nil {
if err := d.writeMsg(w, reply); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
}
@@ -320,15 +324,11 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if r == nil {
return fmt.Errorf("received nil DNS message")
}
// Clear Zero bit from peer responses to prevent external sources from
// manipulating our internal fallthrough signaling mechanism
r.MsgHdr.Zero = false
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
@@ -350,14 +350,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
@@ -370,11 +370,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
logger.Errorf("failed to update domain prefixes: %v", err)
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
d.replaceIPsInDNSResponse(r, newPrefixes)
}
}
@@ -386,22 +386,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
}
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
if len(toAdd) > 0 {
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
@@ -418,9 +418,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
@@ -432,7 +432,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.addDNATMappings(dnatMappings, logger)
d.addDNATMappings(dnatMappings)
if !d.route.KeepRoute {
// Remove old prefixes
@@ -448,7 +448,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
}
}
d.removeDNATMappings(toRemove, logger)
d.removeDNATMappings(toRemove)
}
// Update domain prefixes using resolved domain as key - store real IPs
@@ -463,14 +463,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
}
return nberrors.FormatErrorOrNil(merr)
}
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
if len(realPrefixes) == 0 {
return
}
@@ -484,9 +484,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
}
}
}
@@ -502,7 +502,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
if len(mappings) == 0 {
return
}
@@ -514,9 +514,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, log
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
@@ -528,12 +528,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
d.removeDNATMappings(prefixes)
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
if _, ok := d.internalDnatFw(); !ok {
return
}
@@ -549,7 +549,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
@@ -560,7 +560,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}

View File

@@ -1,4 +1,5 @@
//go:build !windows
// +build !windows
package iface

View File

@@ -210,8 +210,7 @@ func (r *SysOps) refreshLocalSubnetsCache() {
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
switch prefix {
case vars.Defaultv4:
if prefix == vars.Defaultv4 {
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
@@ -234,7 +233,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
}
return nil
case vars.Defaultv6:
} else if prefix == vars.Defaultv6 {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
@@ -256,8 +255,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
switch prefix {
case vars.Defaultv4:
if prefix == vars.Defaultv4 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
@@ -275,7 +273,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
return nberrors.FormatErrorOrNil(result)
case vars.Defaultv6:
} else if prefix == vars.Defaultv6 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
@@ -285,9 +283,9 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
return nberrors.FormatErrorOrNil(result)
default:
return r.removeFromRouteTable(prefix, nextHop)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {

View File

@@ -1,35 +0,0 @@
// Package updatemanager provides automatic update management for the NetBird client.
// It monitors for new versions, handles update triggers from management server directives,
// and orchestrates the download and installation of client updates.
//
// # Overview
//
// The update manager operates as a background service that continuously monitors for
// available updates and automatically initiates the update process when conditions are met.
// It integrates with the installer package to perform the actual installation.
//
// # Update Flow
//
// The complete update process follows these steps:
//
// 1. Manager receives update directive via SetVersion() or detects new version
// 2. Manager validates update should proceed (version comparison, rate limiting)
// 3. Manager publishes "updating" event to status recorder
// 4. Manager persists UpdateState to track update attempt
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
// 6. Manager triggers installation via installer.RunInstallation()
// 7. Installer package handles the actual installation process
// 8. On next startup, CheckUpdateSuccess() verifies update completion
// 9. Manager publishes success/failure event to status recorder
// 10. Manager cleans up UpdateState
//
// # State Management
//
// Update state is persisted across restarts to track update attempts:
//
// - PreUpdateVersion: Version before update attempt
// - TargetVersion: Version attempting to update to
//
// This enables verification of successful updates and appropriate user notification
// after the client restarts with the new version.
package updatemanager

View File

@@ -1,138 +0,0 @@
package downloader
import (
"context"
"fmt"
"io"
"net/http"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
const (
userAgent = "NetBird agent installer/%s"
DefaultRetryDelay = 3 * time.Second
)
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
log.Debugf("starting download from %s", url)
out, err := os.Create(dstFile)
if err != nil {
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
}
defer func() {
if cerr := out.Close(); cerr != nil {
log.Warnf("error closing file %q: %v", dstFile, cerr)
}
}()
// First attempt
err = downloadToFileOnce(ctx, url, out)
if err == nil {
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
// If retryDelay is 0, don't retry
if retryDelay == 0 {
return err
}
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
// Sleep before retry
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
}
// Truncate file before retry
if err := out.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate file on retry: %w", err)
}
if _, err := out.Seek(0, 0); err != nil {
return fmt.Errorf("failed to seek to beginning of file: %w", err)
}
// Second attempt
if err := downloadToFileOnce(ctx, url, out); err != nil {
return fmt.Errorf("download failed after retry: %w", err)
}
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return data, nil
}
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("failed to write response body to file: %w", err)
}
return nil
}
func sleepWithContext(ctx context.Context, duration time.Duration) error {
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
return ctx.Err()
}
}

View File

@@ -1,199 +0,0 @@
package downloader
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
const (
retryDelay = 100 * time.Millisecond
)
func TestDownloadToFile_Success(t *testing.T) {
// Create a test server that responds successfully
content := "test file content"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
}
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
content := "test file content after retry"
var attemptCount atomic.Int32
// Create a test server that fails on first attempt, succeeds on second
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := attemptCount.Add(1)
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should succeed after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
t.Fatalf("expected no error after retry, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
// Verify it took 2 attempts
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should fail after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
t.Fatal("expected error after retry, got nil")
}
// Verify it tried 2 times
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Create a context that will be cancelled during retry delay
ctx, cancel := context.WithCancel(context.Background())
// Cancel after a short delay (during the retry sleep)
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
// Download the file (should fail due to context cancellation during retry)
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
if err == nil {
t.Fatal("expected error due to context cancellation, got nil")
}
// Should have only made 1 attempt (cancelled during retry delay)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_InvalidURL(t *testing.T) {
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
if err == nil {
t.Fatal("expected error for invalid URL, got nil")
}
}
func TestDownloadToFile_InvalidDestination(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("test"))
}))
defer server.Close()
// Use an invalid destination path
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
if err == nil {
t.Fatal("expected error for invalid destination, got nil")
}
}
func TestDownloadToFile_NoRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file with retryDelay = 0 (should not retry)
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
t.Fatal("expected error, got nil")
}
// Verify it only made 1 attempt (no retry)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}

View File

@@ -1,7 +0,0 @@
//go:build !windows
package installer
func UpdaterBinaryNameWithoutExtension() string {
return updaterBinary
}

View File

@@ -1,11 +0,0 @@
package installer
import (
"path/filepath"
"strings"
)
func UpdaterBinaryNameWithoutExtension() string {
ext := filepath.Ext(updaterBinary)
return strings.TrimSuffix(updaterBinary, ext)
}

View File

@@ -1,111 +0,0 @@
// Package installer provides functionality for managing NetBird application
// updates and installations across Windows, macOS. It handles
// the complete update lifecycle including artifact download, cryptographic verification,
// installation execution, process management, and result reporting.
//
// # Architecture
//
// The installer package uses a two-process architecture to enable self-updates:
//
// 1. Service Process: The main NetBird daemon process that initiates updates
// 2. Updater Process: A detached child process that performs the actual installation
//
// This separation is critical because:
// - The service binary cannot update itself while running
// - The installer (EXE/MSI/PKG) will terminate the service during installation
// - The updater process survives service termination and restarts it after installation
// - Results can be communicated back to the service after it restarts
//
// # Update Flow
//
// Service Process (RunInstallation):
//
// 1. Validates target version format (semver)
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
// 3. Downloads installer file from GitHub releases (if applicable)
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
// launching updater)
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
// 6. Launches updater process with detached mode:
// - --temp-dir: Temporary directory path
// - --service-dir: Service installation directory
// - --installer-file: Path to downloaded installer (if applicable)
// - --dry-run: Optional flag to test without actually installing
// 7. Service process continues running (will be terminated by installer later)
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
//
// Updater Process (Setup):
//
// 1. Receives parameters from service via command-line arguments
// 2. Runs installer with appropriate silent/quiet flags:
// - Windows EXE: installer.exe /S
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
// - macOS PKG: installer -pkg installer.pkg -target /
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
// 3. Installer terminates daemon and UI processes
// 4. Installer replaces binaries with new version
// 5. Updater waits for installer to complete
// 6. Updater restarts daemon:
// - Windows: netbird.exe service start
// - macOS/Linux: netbird service start
// 7. Updater restarts UI:
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
// - Linux: Not implemented (UI typically auto-starts)
// 8. Updater writes result.json with success/error status
// 9. Updater process exits
//
// # Result Communication
//
// The ResultHandler (result.go) manages communication between updater and service:
//
// Result Structure:
//
// type Result struct {
// Success bool // true if installation succeeded
// Error string // error message if Success is false
// ExecutedAt time.Time // when installation completed
// }
//
// Result files are automatically cleaned up after being read.
//
// # File Locations
//
// Temporary Directory (platform-specific):
//
// Windows:
// - Path: %ProgramData%\Netbird\tmp-install
// - Example: C:\ProgramData\Netbird\tmp-install
//
// macOS:
// - Path: /var/lib/netbird/tmp-install
// - Requires root permissions
//
// Files created during installation:
//
// tmp-install/
// installer.log
// updater[.exe] # Copy of service binary
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
// result.json # Installation result
// msi.log # MSI verbose log (Windows MSI only)
//
// # API Reference
//
// # Cleanup
//
// CleanUpInstallerFiles() removes temporary files after successful installation:
// - Downloaded installer files (*.exe, *.msi, *.pkg)
// - Updater binary copy
// - Does NOT remove result.json (cleaned by ResultHandler after read)
// - Does NOT remove msi.log (kept for debugging)
//
// # Dry-Run Mode
//
// Dry-run mode allows testing the update process without actually installing:
//
// Enable via environment variable:
//
// export NB_AUTO_UPDATE_DRY_RUN=true
// netbird service install-update 0.29.0
package installer

View File

@@ -1,50 +0,0 @@
//go:build !windows && !darwin
package installer
import (
"context"
"fmt"
)
const (
updaterBinary = "updater"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
func (u *Installer) TempDir() string {
return ""
}
func (c *Installer) LogFiles() []string {
return []string{}
}
func (u *Installer) CleanUpInstallerFiles() error {
return nil
}
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
return fmt.Errorf("unsupported platform")
}
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
return fmt.Errorf("unsupported platform")
}

View File

@@ -1,293 +0,0 @@
//go:build windows || darwin
package installer
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"github.com/hashicorp/go-multierror"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{
tempDir: defaultTempDir,
}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
// RunInstallation starts the updater process to run the installation
// This will run by the original service process
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
resultHandler := NewResultHandler(u.tempDir)
defer func() {
if err != nil {
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
log.Errorf("failed to write error result: %v", writeErr)
}
}
}()
if err := validateTargetVersion(targetVersion); err != nil {
return err
}
if err := u.mkTempDir(); err != nil {
return err
}
var installerFile string
// Download files only when not using any third-party store
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
log.Infof("download installer")
var err error
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
if err != nil {
log.Errorf("failed to download installer: %v", err)
return err
}
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
if err != nil {
log.Errorf("failed to create artifact verify: %v", err)
return err
}
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
log.Errorf("artifact verification error: %v", err)
return err
}
}
log.Infof("running installer")
updaterPath, err := u.copyUpdater()
if err != nil {
return err
}
// the directory where the service has been installed
workspace, err := getServiceDir()
if err != nil {
return err
}
args := []string{
"--temp-dir", u.tempDir,
"--service-dir", workspace,
}
if isDryRunEnabled() {
args = append(args, "--dry-run=true")
}
if installerFile != "" {
args = append(args, "--installer-file", installerFile)
}
updateCmd := exec.Command(updaterPath, args...)
log.Infof("starting updater process: %s", updateCmd.String())
// Configure the updater to run in a separate session/process group
// so it survives the parent daemon being stopped
setUpdaterProcAttr(updateCmd)
// Start the updater process asynchronously
if err := updateCmd.Start(); err != nil {
return err
}
pid := updateCmd.Process.Pid
log.Infof("updater started with PID %d", pid)
// Release the process so the OS can fully detach it
if err := updateCmd.Process.Release(); err != nil {
log.Warnf("failed to release updater process: %v", err)
}
return nil
}
// CleanUpInstallerFiles
// - the installer file (pkg, exe, msi)
// - the selfcopy updater.exe
func (u *Installer) CleanUpInstallerFiles() error {
// Check if tempDir exists
info, err := os.Stat(u.tempDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
if !info.IsDir() {
return nil
}
var merr *multierror.Error
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
}
entries, err := os.ReadDir(u.tempDir)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
for _, ext := range binaryExtensions {
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
}
break
}
}
}
return merr.ErrorOrNil()
}
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
fileURL := urlWithVersionArch(installerType, targetVersion)
// Clean up temp directory on error
var success bool
defer func() {
if !success {
if err := os.RemoveAll(u.tempDir); err != nil {
log.Errorf("error cleaning up temporary directory: %v", err)
}
}
}()
fileName := path.Base(fileURL)
if fileName == "." || fileName == "/" || fileName == "" {
return "", fmt.Errorf("invalid file URL: %s", fileURL)
}
outputFilePath := filepath.Join(u.tempDir, fileName)
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
return "", err
}
success = true
return outputFilePath, nil
}
func (u *Installer) TempDir() string {
return u.tempDir
}
func (u *Installer) mkTempDir() error {
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
log.Debugf("failed to create tempdir: %s", u.tempDir)
return err
}
return nil
}
func (u *Installer) copyUpdater() (string, error) {
src, err := getServiceBinary()
if err != nil {
return "", fmt.Errorf("failed to get updater binary: %w", err)
}
dst := filepath.Join(u.tempDir, updaterBinary)
if err := copyFile(src, dst); err != nil {
return "", fmt.Errorf("failed to copy updater binary: %w", err)
}
if err := os.Chmod(dst, 0o755); err != nil {
return "", fmt.Errorf("failed to set permissions: %w", err)
}
return dst, nil
}
func validateTargetVersion(targetVersion string) error {
if targetVersion == "" {
return fmt.Errorf("target version cannot be empty")
}
_, err := goversion.NewVersion(targetVersion)
if err != nil {
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
}
return nil
}
func copyFile(src, dst string) error {
log.Infof("copying %s to %s", src, dst)
in, err := os.Open(src)
if err != nil {
return fmt.Errorf("open source: %w", err)
}
defer func() {
if err := in.Close(); err != nil {
log.Warnf("failed to close source file: %v", err)
}
}()
out, err := os.Create(dst)
if err != nil {
return fmt.Errorf("create destination: %w", err)
}
defer func() {
if err := out.Close(); err != nil {
log.Warnf("failed to close destination file: %v", err)
}
}()
if _, err := io.Copy(out, in); err != nil {
return fmt.Errorf("copy: %w", err)
}
return nil
}
func getServiceDir() (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", err
}
return filepath.Dir(exePath), nil
}
func getServiceBinary() (string, error) {
return os.Executable()
}
func isDryRunEnabled() bool {
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
}

View File

@@ -1,11 +0,0 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -1,12 +0,0 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, msiLogFile),
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -1,238 +0,0 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
const (
daemonName = "netbird"
updaterBinary = "updater"
uiBinary = "/Applications/NetBird.app"
defaultTempDir = "/var/lib/netbird/tmp-install"
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
var (
binaryExtensions = []string{"pkg"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
// skip service restart if dry-run mode is enabled
if dryRun {
return
}
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(); err != nil {
log.Errorf("failed to start UI: %v", err)
}
}()
if dryRun {
time.Sleep(7 * time.Second)
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
switch TypeOfInstaller(ctx) {
case TypePKG:
resultErr = u.installPkgFile(ctx, installerFile)
case TypeHomebrew:
resultErr = u.updateHomeBrew(ctx)
}
return resultErr
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser() error {
log.Infof("starting netbird-ui: %s", uiBinary)
// Get the current console user
cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("failed to get console user: %w", err)
}
username := strings.TrimSpace(string(output))
if username == "" || username == "root" {
return fmt.Errorf("no active user session found")
}
log.Infof("starting UI for user: %s", username)
// Get user's UID
userInfo, err := user.Lookup(username)
if err != nil {
return fmt.Errorf("failed to lookup user %s: %w", username, err)
}
// Start the UI process as the console user using launchctl
// This ensures the app runs in the user's context with proper GUI access
launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
log.Infof("launchCmd: %s", launchCmd.String())
// Set the user's home directory for proper macOS app behavior
launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
if err := launchCmd.Start(); err != nil {
return fmt.Errorf("failed to start UI process: %w", err)
}
// Release the process so it can run independently
if err := launchCmd.Process.Release(); err != nil {
log.Warnf("failed to release UI process: %v", err)
}
log.Infof("netbird-ui started successfully for user %s", username)
return nil
}
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
log.Infof("installing pkg file: %s", path)
// Kill any existing UI processes before installation
// This ensures the postinstall script's "open $APP" will start the new version
u.killUI()
volume := "/"
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
if err := cmd.Start(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if err := cmd.Wait(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("pkg file installed successfully")
return nil
}
func (u *Installer) updateHomeBrew(ctx context.Context) error {
log.Infof("updating homebrew")
// Kill any existing UI processes before upgrade
// This ensures the new version will be started after upgrade
u.killUI()
// Homebrew must be run as a non-root user
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
// Check both Apple Silicon and Intel Mac paths
brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath := "/opt/homebrew/bin/brew"
if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
// Try Intel Mac path
brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath = "/usr/local/bin/brew"
}
fileInfo, err := os.Stat(brewTapPath)
if err != nil {
return fmt.Errorf("error getting homebrew installation path info: %w", err)
}
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
}
// Get username from UID
brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
if err != nil {
return fmt.Errorf("error looking up brew installer user: %w", err)
}
userName := brewUser.Username
// Get user HOME, required for brew to run correctly
// https://github.com/Homebrew/brew/issues/15833
homeDir := brewUser.HomeDir
// Check if netbird-ui is installed (must run as the brew user, not root)
checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
uiInstalled := checkUICmd.Run() == nil
// Homebrew does not support installing specific versions
// Thus it will always update to latest and ignore targetVersion
upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
if uiInstalled {
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
}
cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
cmd.Env = append(os.Environ(), "HOME="+homeDir)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
}
log.Infof("homebrew updated successfully")
return nil
}
func (u *Installer) killUI() {
log.Infof("killing existing netbird-ui processes")
cmd := exec.Command("pkill", "-x", "netbird-ui")
if output, err := cmd.CombinedOutput(); err != nil {
// pkill returns exit code 1 if no processes matched, which is fine
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
} else {
log.Infof("netbird-ui processes killed")
}
}
func urlWithVersionArch(_ Type, version string) string {
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -1,213 +0,0 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
daemonName = "netbird.exe"
uiName = "netbird-ui.exe"
updaterBinary = "updater.exe"
msiLogFile = "msi.log"
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
)
var (
defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
// for the cleanup
binaryExtensions = []string{"msi", "exe"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(daemonFolder); err != nil {
log.Errorf("failed to start UI: %v", err)
}
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
}()
if dryRun {
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
installerType, err := typeByFileExtension(installerFile)
if err != nil {
log.Debugf("%v", err)
resultErr = err
return
}
var cmd *exec.Cmd
switch installerType {
case TypeExe:
log.Infof("run exe installer: %s", installerFile)
cmd = exec.CommandContext(ctx, installerFile, "/S")
default:
installerDir := filepath.Dir(installerFile)
logPath := filepath.Join(installerDir, msiLogFile)
log.Infof("run msi installer: %s", installerFile)
cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
}
cmd.Dir = filepath.Dir(installerFile)
if resultErr = cmd.Start(); resultErr != nil {
log.Errorf("error starting installer: %v", resultErr)
return
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if resultErr = cmd.Wait(); resultErr != nil {
log.Errorf("installer process finished with error: %v", resultErr)
return
}
return nil
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser(daemonFolder string) error {
uiPath := filepath.Join(daemonFolder, uiName)
log.Infof("starting netbird-ui: %s", uiPath)
// Get the active console session ID
sessionID := windows.WTSGetActiveConsoleSessionId()
if sessionID == 0xFFFFFFFF {
return fmt.Errorf("no active user session found")
}
// Get the user token for that session
var userToken windows.Token
err := windows.WTSQueryUserToken(sessionID, &userToken)
if err != nil {
return fmt.Errorf("failed to query user token: %w", err)
}
defer func() {
if err := userToken.Close(); err != nil {
log.Warnf("failed to close user token: %v", err)
}
}()
// Duplicate the token to a primary token
var primaryToken windows.Token
err = windows.DuplicateTokenEx(
userToken,
windows.MAXIMUM_ALLOWED,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
&primaryToken,
)
if err != nil {
return fmt.Errorf("failed to duplicate token: %w", err)
}
defer func() {
if err := primaryToken.Close(); err != nil {
log.Warnf("failed to close token: %v", err)
}
}()
// Prepare startup info
var si windows.StartupInfo
si.Cb = uint32(unsafe.Sizeof(si))
si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
var pi windows.ProcessInformation
cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
if err != nil {
return fmt.Errorf("failed to convert path to UTF16: %w", err)
}
creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
err = windows.CreateProcessAsUser(
primaryToken,
nil,
cmdLine,
nil,
nil,
false,
creationFlags,
nil,
nil,
&si,
&pi,
)
if err != nil {
return fmt.Errorf("CreateProcessAsUser failed: %w", err)
}
// Close handles
if err := windows.CloseHandle(pi.Process); err != nil {
log.Warnf("failed to close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Warnf("failed to close thread handle: %v", err)
}
log.Infof("netbird-ui started successfully in session %d", sessionID)
return nil
}
func urlWithVersionArch(it Type, version string) string {
var url string
if it == TypeExe {
url = exeDownloadURL
} else {
url = msiDownloadURL
}
url = strings.ReplaceAll(url, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -1,5 +0,0 @@
package installer
const (
LogFile = "installer.log"
)

View File

@@ -1,15 +0,0 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run in a new session,
// making it independent of the parent daemon process. This ensures the updater
// survives when the daemon is stopped during the pkg installation.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
}

View File

@@ -1,14 +0,0 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run detached from the parent,
// making it independent of the parent daemon process.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
}
}

View File

@@ -1,7 +0,0 @@
//go:build devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
)

View File

@@ -1,7 +0,0 @@
//go:build !devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
)

View File

@@ -1,230 +0,0 @@
package installer
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
const (
resultFile = "result.json"
)
type Result struct {
Success bool
Error string
ExecutedAt time.Time
}
// ResultHandler handles reading and writing update results
type ResultHandler struct {
resultFile string
}
// NewResultHandler creates a new communicator with the given directory path
// The result file will be created as "result.json" in the specified directory
func NewResultHandler(installerDir string) *ResultHandler {
// Create it if it doesn't exist
// do not care if already exists
_ = os.MkdirAll(installerDir, 0o700)
rh := &ResultHandler{
resultFile: filepath.Join(installerDir, resultFile),
}
return rh
}
func (rh *ResultHandler) GetErrorResultReason() string {
result, err := rh.tryReadResult()
if err == nil && !result.Success {
return result.Error
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return ""
}
func (rh *ResultHandler) WriteSuccess() error {
result := Result{
Success: true,
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) WriteErr(errReason error) error {
result := Result{
Success: false,
Error: errReason.Error(),
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
log.Infof("start watching result: %s", rh.resultFile)
// Check if file already exists (updater finished before we started watching)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
dir := filepath.Dir(rh.resultFile)
if err := rh.waitForDirectory(ctx, dir); err != nil {
return Result{}, err
}
return rh.watchForResultFile(ctx, dir)
}
func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
ticker := time.NewTicker(300 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return nil
}
}
}
}
func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Error(err)
return Result{}, err
}
defer func() {
if err := watcher.Close(); err != nil {
log.Warnf("failed to close watcher: %v", err)
}
}()
if err := watcher.Add(dir); err != nil {
return Result{}, fmt.Errorf("failed to watch directory: %v", err)
}
// Check again after setting up watcher to avoid race condition
// (file could have been created between initial check and watcher setup)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
for {
select {
case <-ctx.Done():
return Result{}, ctx.Err()
case event, ok := <-watcher.Events:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
if result, done := rh.handleWatchEvent(event); done {
return result, nil
}
case err, ok := <-watcher.Errors:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
return Result{}, fmt.Errorf("watcher error: %w", err)
}
}
}
func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
if event.Name != rh.resultFile {
return Result{}, false
}
if event.Has(fsnotify.Create) {
result, err := rh.tryReadResult()
if err != nil {
log.Debugf("error while reading result: %v", err)
return result, true
}
log.Infof("installer result: %v", result)
return result, true
}
return Result{}, false
}
// Write writes the update result to a file for the UI to read
func (rh *ResultHandler) write(result Result) error {
log.Infof("write out installer result to: %s", rh.resultFile)
// Ensure directory exists
dir := filepath.Dir(rh.resultFile)
if err := os.MkdirAll(dir, 0o755); err != nil {
log.Errorf("failed to create directory %s: %v", dir, err)
return err
}
data, err := json.Marshal(result)
if err != nil {
return err
}
// Write to a temporary file first, then rename for atomic operation
tmpPath := rh.resultFile + ".tmp"
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
log.Errorf("failed to create temp file: %s", err)
return err
}
// Atomic rename
if err := os.Rename(tmpPath, rh.resultFile); err != nil {
if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
log.Warnf("Failed to remove temp result file: %v", err)
}
return err
}
return nil
}
func (rh *ResultHandler) cleanup() error {
err := os.Remove(rh.resultFile)
if err != nil && !os.IsNotExist(err) {
return err
}
log.Debugf("delete installer result file: %s", rh.resultFile)
return nil
}
// tryReadResult attempts to read and validate the result file
func (rh *ResultHandler) tryReadResult() (Result, error) {
data, err := os.ReadFile(rh.resultFile)
if err != nil {
return Result{}, err
}
var result Result
if err := json.Unmarshal(data, &result); err != nil {
return Result{}, fmt.Errorf("invalid result format: %w", err)
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return result, nil
}

View File

@@ -1,14 +0,0 @@
package installer
type Type struct {
name string
downloadable bool
}
func (t Type) String() string {
return t.name
}
func (t Type) Downloadable() bool {
return t.downloadable
}

View File

@@ -1,22 +0,0 @@
package installer
import (
"context"
"os/exec"
)
var (
TypeHomebrew = Type{name: "Homebrew", downloadable: false}
TypePKG = Type{name: "pkg", downloadable: true}
)
func TypeOfInstaller(ctx context.Context) Type {
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
_, err := cmd.Output()
if err != nil && cmd.ProcessState.ExitCode() == 1 {
// Not installed using pkg file, thus installed using Homebrew
return TypeHomebrew
}
return TypePKG
}

View File

@@ -1,51 +0,0 @@
package installer
import (
"context"
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
)
const (
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
)
var (
TypeExe = Type{name: "EXE", downloadable: true}
TypeMSI = Type{name: "MSI", downloadable: true}
)
func TypeOfInstaller(_ context.Context) Type {
paths := []string{uninstallKeyPath64, uninstallKeyPath32}
for _, path := range paths {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
continue
}
if err := k.Close(); err != nil {
log.Warnf("Error closing registry key: %v", err)
}
return TypeExe
}
log.Debug("No registry entry found for Netbird, assuming MSI installation")
return TypeMSI
}
func typeByFileExtension(filePath string) (Type, error) {
switch {
case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
return TypeExe, nil
case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
return TypeMSI, nil
default:
return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
}
}

View File

@@ -1,374 +0,0 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
v "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
const (
latestVersion = "latest"
// this version will be ignored
developmentVersion = "development"
)
var errNoUpdateState = errors.New("no update state found")
type UpdateState struct {
PreUpdateVersion string
TargetVersion string
}
func (u UpdateState) Name() string {
return "autoUpdate"
}
type Manager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
currentVersion string
update UpdateInterface
wg sync.WaitGroup
cancel context.CancelFunc
expectedVersion *v.Version
updateToLatestVersion bool
// updateMutex protect update and expectedVersion fields
updateMutex sync.Mutex
triggerUpdateFn func(context.Context, string) error
}
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
if runtime.GOOS == "darwin" {
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Home Brew installation")
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
}
}
return newManager(statusRecorder, stateManager)
}
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
}
manager.triggerUpdateFn = manager.triggerUpdate
stateManager.RegisterState(&UpdateState{})
return manager, nil
}
// CheckUpdateSuccess checks if the update was successful and send a notification.
// It works without to start the update manager.
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
reason := m.lastResultErrReason()
if reason != "" {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %s", reason),
nil,
)
}
updateState, err := m.loadAndDeleteUpdateState(ctx)
if err != nil {
if errors.Is(err, errNoUpdateState) {
return
}
log.Errorf("failed to load update state: %v", err)
return
}
log.Debugf("auto-update state loaded, %v", *updateState)
if updateState.TargetVersion == m.currentVersion {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Auto-update completed",
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
nil,
)
return
}
}
func (m *Manager) Start(ctx context.Context) {
if m.cancel != nil {
log.Errorf("Manager already started")
return
}
m.update.SetDaemonVersion(m.currentVersion)
m.update.SetOnUpdateListener(func() {
select {
case m.updateChannel <- struct{}{}:
default:
}
})
go m.update.StartFetcher()
ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel
m.wg.Add(1)
go m.updateLoop(ctx)
}
func (m *Manager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if m.cancel == nil {
log.Errorf("manager not started")
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if expectedVersion == "" {
log.Errorf("empty expected version provided")
m.expectedVersion = nil
m.updateToLatestVersion = false
return
}
if expectedVersion == latestVersion {
m.updateToLatestVersion = true
m.expectedVersion = nil
} else {
expectedSemVer, err := v.NewVersion(expectedVersion)
if err != nil {
log.Errorf("error parsing version: %v", err)
return
}
if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
return
}
m.expectedVersion = expectedSemVer
m.updateToLatestVersion = false
}
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
func (m *Manager) Stop() {
if m.cancel == nil {
return
}
m.cancel()
m.updateMutex.Lock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
m.updateMutex.Unlock()
m.wg.Wait()
}
func (m *Manager) onContextCancel() {
if m.cancel == nil {
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
}
func (m *Manager) updateLoop(ctx context.Context) {
defer m.wg.Done()
for {
select {
case <-ctx.Done():
m.onContextCancel()
return
case <-m.mgmUpdateChan:
case <-m.updateChannel:
log.Infof("fetched new version info")
}
m.handleUpdate(ctx)
}
}
func (m *Manager) handleUpdate(ctx context.Context) {
var updateVersion *v.Version
m.updateMutex.Lock()
if m.update == nil {
m.updateMutex.Unlock()
return
}
expectedVersion := m.expectedVersion
useLatest := m.updateToLatestVersion
curLatestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
switch {
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
return
}
updateVersion = curLatestVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("no expected version information set")
return
}
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
if !m.shouldUpdate(updateVersion) {
return
}
m.lastTrigger = time.Now()
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show", "version": updateVersion.String()},
)
updateState := UpdateState{
PreUpdateVersion: m.currentVersion,
TargetVersion: updateVersion.String(),
}
if err := m.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
if err = m.stateManager.PersistState(ctx); err != nil {
log.Warnf("failed to persist state: %v", err)
}
}
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
}
}
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
// Returns nil if no state exists.
func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
stateType := &UpdateState{}
m.stateManager.RegisterState(stateType)
if err := m.stateManager.LoadState(stateType); err != nil {
return nil, fmt.Errorf("load state: %w", err)
}
state := m.stateManager.GetState(stateType)
if state == nil {
return nil, errNoUpdateState
}
updateState, ok := state.(*UpdateState)
if !ok {
return nil, fmt.Errorf("failed to cast state to UpdateState")
}
if err := m.stateManager.DeleteState(updateState); err != nil {
return nil, fmt.Errorf("delete state: %w", err)
}
if err := m.stateManager.PersistState(ctx); err != nil {
return nil, fmt.Errorf("persist state: %w", err)
}
return updateState, nil
}
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
}
currentVersion, err := v.NewVersion(m.currentVersion)
if err != nil {
log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
return false
}
if currentVersion.GreaterThanOrEqual(updateVersion) {
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
return false
}
if time.Since(m.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
return false
}
return true
}
func (m *Manager) lastResultErrReason() string {
inst := installer.New()
result := installer.NewResultHandler(inst.TempDir())
return result.GetErrorResultReason()
}
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
inst := installer.New()
return inst.RunInstallation(ctx, targetVersion)
}

View File

@@ -1,214 +0,0 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (v versionUpdateMock) StopWatch() {}
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
v.onUpdate = updateFn
}
func (v versionUpdateMock) LatestVersion() *v.Version {
return v.latestVersion
}
func (v versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = mockUpdate
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}

Some files were not shown because too many files have changed in this diff Show More