Compare commits

...

25 Commits

Author SHA1 Message Date
mlsmaycon
279e96e6b1 add disk encryption check 2026-01-17 19:56:50 +01:00
Maycon Santos
245481f33b [client] fix: client/Dockerfile to reduce vulnerabilities (#5119)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-01-16 18:05:41 +01:00
shuuri-labs
b352ab84c0 Feat/quickstart reverse proxy assistant (#5100)
* add external reverse proxy config steps to quickstart script

* remove generated files

* - Remove 'press enter' prompt from post-traefik config since traefik requires no manual config
- Improve npm flow (ask users for docker network, user container names in config)

* fixes for npm flow

* nginx flow fixes

* caddy flow fixes

* Consolidate NPM_NETWORK, NGINX_NETWORK, CADDY_NETWORK into single
EXTERNAL_PROXY_NETWORK variable. Add read_proxy_docker_network()
function that prompts for Docker network for options 2-4 (Nginx,
NPM, Caddy). Generated configs now use container names when a
Docker network is specified.

* fix https for traefik

* fix sonar code smells

* fix sonar smell (add return to render_dashboard_env)

* added tls instructions to nginx flow

* removed unused bind_addr variable from quickstart.sh

* Refactor getting-started.sh for improved maintainability

Break down large functions into focused, single-responsibility components:
- Split init_environment() into 6 initialization functions
- Split print_post_setup_instructions() into 6 proxy-specific functions
- Add section headers for better code organization
- Fix 3 code smell issues (unused bind_addr variables)
- Add TLS certificate documentation for Nginx
- Link reverse proxy names to docs sections

Reduces largest function from 205 to ~90 lines while maintaining
single-file distribution. No functional changes.

* - Remove duplicate network display logic in Traefik instructions
- Use upstream_host instead of bind_addr for NPM forward hostname
- Use upstream_host instead of bind_addr in manual proxy route examples
- Prevents displaying invalid 0.0.0.0 as connection target in setup instructions

* add wait_management_direct to caddy flow to ensure script waits until containers are running/passing healthchecks before reporting 'done!'
2026-01-16 17:42:28 +01:00
ressys1978
3ce5d6a4f8 [management] Add idp timeout env variable (#4647)
Introduced the NETBIRD_IDP_TIMEOUT environment variable to the management service. This allows configuring a timeout for supported IDPs. If the variable is unset or contains an invalid value, a default timeout of 10 seconds is used as a fallback.

This is needed for larger IDP environments, where the default 10 seconds is not enough time.
2026-01-16 16:23:37 +01:00
Misha Bragin
4c2eb2af73 [management] Skip email_verified if not present (#5118) 2026-01-16 16:01:39 +01:00
Misha Bragin
daf1449174 [client] Remove duplicate audiences check (#5117) 2026-01-16 14:25:02 +02:00
Misha Bragin
1ff7abe909 [management, client] Fix SSH server audience validator (#5105)
* **New Features**
  * SSH server JWT validation now accepts multiple audiences with backward-compatible handling of the previous single-audience setting and a guard ensuring at least one audience is configured.
* **Tests**
  * Test suites updated and new tests added to cover multiple-audience scenarios and compatibility with existing behavior.
* **Other**
  * Startup logging enhanced to report configured audiences for JWT auth.
2026-01-16 12:28:17 +01:00
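The multiple-audience validation with a guard for an empty configuration, as described above, can be sketched like this. The function name `validateAudience` is hypothetical and does not match NetBird's actual API; backward compatibility with a single legacy audience falls out naturally, since a single audience is just a one-element list.

```go
package main

import (
	"errors"
	"fmt"
)

// validateAudience checks a token's audience claim against the configured
// list, rejecting outright when no audience is configured (the guard).
func validateAudience(tokenAud string, allowed []string) error {
	if len(allowed) == 0 {
		return errors.New("no audience configured")
	}
	for _, a := range allowed {
		if a == tokenAud {
			return nil
		}
	}
	return fmt.Errorf("audience %q not accepted", tokenAud)
}

func main() {
	// A legacy single-audience setting is handled as a one-element list.
	if err := validateAudience("netbird-ssh", []string{"netbird-ssh", "netbird-cli"}); err == nil {
		fmt.Println("audience accepted")
	}
}
```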
Bethuel Mmbaga
067c77e49e [management] Add custom dns zones (#4849) 2026-01-16 12:12:05 +03:00
Maycon Santos
291e640b28 [client] Change priority between local and dns route handlers (#5106)
* Change priority between local and dns route handlers

* update priority tests
2026-01-15 17:30:10 +01:00
Pascal Fischer
efb954b7d6 [management] adapt ratelimiting (#5080) 2026-01-15 16:39:14 +01:00
Vlad
cac9326d3d [management] fetch all users data from external cache in one request (#5104)
---------

Co-authored-by: pascal <pascal@netbird.io>
2026-01-14 17:09:17 +01:00
Viktor Liu
520d9c66cf [client] Fix netstack upstream dns and add wasm debug methods (#4648) 2026-01-14 13:56:16 +01:00
Misha Bragin
ff10498a8b Feature/embedded STUN (#5062) 2026-01-14 13:13:30 +01:00
Zoltan Papp
00b747ad5d Handle fallback for invalid loginuid in ui-post-install.sh. (#5099) 2026-01-14 09:53:14 +01:00
Zoltan Papp
d9118eb239 [client] Fix WASM peer connection to lazy peers (#5097)
WASM peers now properly initiate relay connections instead of waiting for offers that lazy peers won't send.
2026-01-13 13:33:15 +01:00
Nima Sadeghifard
94de656fae [misc] Add hiring announcement with link to careers.netbird.io (#5095) 2026-01-12 19:06:28 +01:00
Misha Bragin
37abab8b69 [management] Check config compatibility (#5087)
* Enforce HttpConfig overwrite when embeddedIdp is enabled

* Disable offline_access scope in dashboard by default

* Add group propagation foundation to embedded idp

* Require groups scope in dex config for okt and pocket

* remove offline_access from device default scopes
2026-01-12 17:09:03 +01:00
Viktor Liu
b12c084a50 [client] Fall through dns chain for custom dns zones (#5081) 2026-01-12 13:56:39 +01:00
Viktor Liu
394ad19507 [client] Chase CNAMEs in local resolver to ensure musl compatibility (#5046) 2026-01-12 12:35:38 +01:00
Misha Bragin
614e7d5b90 Validate OIDC issuer when creating or updating (#5074) 2026-01-09 09:45:43 -05:00
Misha Bragin
f7967f9ae3 Feature/resolve local jwks keys (#5073) 2026-01-09 09:41:27 -05:00
Vlad
684fc0d2a2 [management] fix the issue with duplicated peers with the same key (#5053) 2026-01-09 11:49:26 +01:00
Viktor Liu
0ad0c81899 [client] Reorder userspace ACL checks to fail faster for better performance (#4226) 2026-01-09 09:13:04 +01:00
Viktor Liu
e8863fbb55 [client] Add non-root ICMP support to userspace firewall forwarder (#4792) 2026-01-09 02:53:37 +08:00
Zoltan Papp
9c9d8e17d7 Revert "Revert "[relay] Update GO version and QUIC version (#4736)" (#5055)" (#5071)
This reverts commit 24df442198.
2026-01-08 18:58:22 +01:00
210 changed files with 12460 additions and 2063 deletions

View File

@@ -1,15 +1,15 @@
-FROM golang:1.23-bullseye
+FROM golang:1.25-bookworm
 RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
     && apt-get -y install --no-install-recommends\
-    gettext-base=0.21-4 \
-    iptables=1.8.7-1 \
-    libgl1-mesa-dev=20.3.5-1 \
-    xorg-dev=1:7.7+22 \
-    libayatana-appindicator3-dev=0.5.5-2+deb11u2 \
+    gettext-base=0.21-12 \
+    iptables=1.8.9-2 \
+    libgl1-mesa-dev=22.3.6-1+deb12u1 \
+    xorg-dev=1:7.7+23 \
+    libayatana-appindicator3-dev=0.5.92-1 \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/* \
-    && go install -v golang.org/x/tools/gopls@v0.18.1
+    && go install -v golang.org/x/tools/gopls@latest
 WORKDIR /app

View File

@@ -25,7 +25,7 @@ jobs:
       release: "14.2"
       prepare: |
         pkg install -y curl pkgconf xorg
-        GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
+        GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
         GO_URL="https://go.dev/dl/$GO_TARBALL"
         curl -vLO "$GO_URL"
         tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -200,7 +200,7 @@ jobs:
           -e GOCACHE=${CONTAINER_GOCACHE} \
           -e GOMODCACHE=${CONTAINER_GOMODCACHE} \
           -e CONTAINER=${CONTAINER} \
-          golang:1.24-alpine \
+          golang:1.25-alpine \
           sh -c ' \
             apk update; apk add --no-cache \
             ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -259,7 +259,7 @@ jobs:
          CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
          go test ${{ matrix.raceFlag }} \
            -exec 'sudo' \
-           -timeout 10m ./relay/... ./shared/relay/...
+           -timeout 10m -p 1 ./relay/... ./shared/relay/...
   test_signal:
     name: "Signal / Unit"

View File

@@ -52,7 +52,10 @@ jobs:
        if: matrix.os == 'ubuntu-latest'
        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
      - name: golangci-lint
-       uses: golangci/golangci-lint-action@v4
+       uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
        with:
          version: latest
-         args: --timeout=12m --out-format colored-line-number
+         skip-cache: true
+         skip-save-cache: true
+         cache-invalidation-interval: 0
+         args: --timeout=12m

View File

@@ -63,7 +63,7 @@ jobs:
          pkg install -y git curl portlint go
          # Install Go for building
-         GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
+         GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
          GO_URL="https://go.dev/dl/$GO_TARBALL"
          curl -LO "$GO_URL"
          tar -C /usr/local -xzf "$GO_TARBALL"

View File

@@ -14,6 +14,9 @@ jobs:
   js_lint:
     name: "JS / Lint"
     runs-on: ubuntu-latest
+    env:
+      GOOS: js
+      GOARCH: wasm
     steps:
       - name: Checkout repository
        uses: actions/checkout@v4
@@ -24,16 +27,14 @@ jobs:
      - name: Install dependencies
        run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
      - name: Install golangci-lint
-       uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
+       uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
        with:
          version: latest
          install-mode: binary
          skip-cache: true
-         skip-pkg-cache: true
-         skip-build-cache: true
-     - name: Run golangci-lint for WASM
-       run: |
-         GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
+         skip-save-cache: true
+         cache-invalidation-interval: 0
+         working-directory: ./client
        continue-on-error: true
   js_build:

View File

@@ -1,139 +1,124 @@
-run:
-  # Timeout for analysis, e.g. 30s, 5m.
-  # Default: 1m
-  timeout: 6m
-
-# This file contains only configs which differ from defaults.
-# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
-linters-settings:
-  errcheck:
-    # Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
-    # Such cases aren't reported by default.
-    # Default: false
-    check-type-assertions: false
-  gosec:
-    includes:
-      - G101 # Look for hard coded credentials
-      #- G102 # Bind to all interfaces
-      - G103 # Audit the use of unsafe block
-      - G104 # Audit errors not checked
-      - G106 # Audit the use of ssh.InsecureIgnoreHostKey
-      #- G107 # Url provided to HTTP request as taint input
-      - G108 # Profiling endpoint automatically exposed on /debug/pprof
-      - G109 # Potential Integer overflow made by strconv.Atoi result conversion to int16/32
-      - G110 # Potential DoS vulnerability via decompression bomb
-      - G111 # Potential directory traversal
-      #- G112 # Potential slowloris attack
-      - G113 # Usage of Rat.SetString in math/big with an overflow (CVE-2022-23772)
-      #- G114 # Use of net/http serve function that has no support for setting timeouts
-      - G201 # SQL query construction using format string
-      - G202 # SQL query construction using string concatenation
-      - G203 # Use of unescaped data in HTML templates
-      #- G204 # Audit use of command execution
-      - G301 # Poor file permissions used when creating a directory
-      - G302 # Poor file permissions used with chmod
-      - G303 # Creating tempfile using a predictable path
-      - G304 # File path provided as taint input
-      - G305 # File traversal when extracting zip/tar archive
-      - G306 # Poor file permissions used when writing to a new file
-      - G307 # Poor file permissions used when creating a file with os.Create
-      #- G401 # Detect the usage of DES, RC4, MD5 or SHA1
-      #- G402 # Look for bad TLS connection settings
-      - G403 # Ensure minimum RSA key length of 2048 bits
-      #- G404 # Insecure random number source (rand)
-      #- G501 # Import blocklist: crypto/md5
-      - G502 # Import blocklist: crypto/des
-      - G503 # Import blocklist: crypto/rc4
-      - G504 # Import blocklist: net/http/cgi
-      #- G505 # Import blocklist: crypto/sha1
-      - G601 # Implicit memory aliasing of items from a range statement
-      - G602 # Slice access out of bounds
-  gocritic:
-    disabled-checks:
-      - commentFormatting
-      - captLocal
-      - deprecatedComment
-  govet:
-    # Enable all analyzers.
-    # Default: false
-    enable-all: false
-    enable:
-      - nilness
-  revive:
-    rules:
-      - name: exported
-        severity: warning
-        disabled: false
-        arguments:
-          - "checkPrivateReceivers"
-          - "sayRepetitiveInsteadOfStutters"
-  tenv:
-    # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
-    # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
-    # Default: false
-    all: true
-
-linters:
-  disable-all: true
-  enable:
-    ## enabled by default
-    - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
-    - gosimple # specializes in simplifying a code
-    - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
-    - ineffassign # detects when assignments to existing variables are not used
-    - staticcheck # is a go vet on steroids, applying a ton of static analysis checks
-    - tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
-    - typecheck # like the front-end of a Go compiler, parses and type-checks Go code
-    - unused # checks for unused constants, variables, functions and types
-    ## disable by default but the have interesting results so lets add them
-    - bodyclose # checks whether HTTP response body is closed successfully
-    - dupword # dupword checks for duplicate words in the source code
-    - durationcheck # durationcheck checks for two durations multiplied together
-    - forbidigo # forbidigo forbids identifiers
-    - gocritic # provides diagnostics that check for bugs, performance and style issues
-    - gosec # inspects source code for security problems
-    - mirror # mirror reports wrong mirror patterns of bytes/strings usage
-    - misspell # misspess finds commonly misspelled English words in comments
-    - nilerr # finds the code that returns nil even if it checks that the error is not nil
-    - nilnil # checks that there is no simultaneous return of nil error and an invalid value
-    - predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
-    - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
-    - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
-    # - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
-    - wastedassign # wastedassign finds wasted assignment statements
-
-issues:
-  # Maximum count of issues with the same text.
-  # Set to 0 to disable.
-  # Default: 3
-  max-same-issues: 5
-
-  exclude-rules:
-    # allow fmt
-    - path: management/cmd/root\.go
-      linters: forbidigo
-    - path: signal/cmd/root\.go
-      linters: forbidigo
-    - path: sharedsock/filter\.go
-      linters:
-        - unused
-    - path: client/firewall/iptables/rule\.go
-      linters:
-        - unused
-    - path: test\.go
-      linters:
-        - mirror
-        - gosec
-    - path: mock\.go
-      linters:
-        - nilnil
-    # Exclude specific deprecation warnings for grpc methods
-    - linters:
-        - staticcheck
-      text: "grpc.DialContext is deprecated"
-    - linters:
-        - staticcheck
-      text: "grpc.WithBlock is deprecated"
+version: "2"
+linters:
+  default: none
+  enable:
+    - bodyclose
+    - dupword
+    - durationcheck
+    - errcheck
+    - forbidigo
+    - gocritic
+    - gosec
+    - govet
+    - ineffassign
+    - mirror
+    - misspell
+    - nilerr
+    - nilnil
+    - predeclared
+    - revive
+    - sqlclosecheck
+    - staticcheck
+    - unused
+    - wastedassign
+  settings:
+    errcheck:
+      check-type-assertions: false
+    gocritic:
+      disabled-checks:
+        - commentFormatting
+        - captLocal
+        - deprecatedComment
+    gosec:
+      includes:
+        - G101
+        - G103
+        - G104
+        - G106
+        - G108
+        - G109
+        - G110
+        - G111
+        - G201
+        - G202
+        - G203
+        - G301
+        - G302
+        - G303
+        - G304
+        - G305
+        - G306
+        - G307
+        - G403
+        - G502
+        - G503
+        - G504
+        - G601
+        - G602
+    govet:
+      enable:
+        - nilness
+      enable-all: false
+    revive:
+      rules:
+        - name: exported
+          arguments:
+            - checkPrivateReceivers
+            - sayRepetitiveInsteadOfStutters
+          severity: warning
+          disabled: false
+  exclusions:
+    generated: lax
+    presets:
+      - comments
+      - common-false-positives
+      - legacy
+      - std-error-handling
+    rules:
+      - linters:
+          - forbidigo
+        path: management/cmd/root\.go
+      - linters:
+          - forbidigo
+        path: signal/cmd/root\.go
+      - linters:
+          - unused
+        path: sharedsock/filter\.go
+      - linters:
+          - unused
+        path: client/firewall/iptables/rule\.go
+      - linters:
+          - gosec
+          - mirror
+        path: test\.go
+      - linters:
+          - nilnil
+        path: mock\.go
+      - linters:
+          - staticcheck
+        text: grpc.DialContext is deprecated
+      - linters:
+          - staticcheck
+        text: grpc.WithBlock is deprecated
+      - linters:
+          - staticcheck
+        text: "QF1001"
+      - linters:
+          - staticcheck
+        text: "QF1008"
+      - linters:
+          - staticcheck
+        text: "QF1012"
+    paths:
+      - third_party$
+      - builtin$
+      - examples$
+issues:
+  max-same-issues: 5
+formatters:
+  exclusions:
+    generated: lax
+    paths:
+      - third_party$
+      - builtin$
+      - examples$

View File

@@ -38,6 +38,11 @@
   </strong>
   <br>
+  <strong>
+    🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
+  </strong>
+  <br>
+  <br>
   <a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
     New: NetBird terraform provider
   </a>

View File

@@ -4,7 +4,7 @@
 # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
 # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
-FROM alpine:3.22.2
+FROM alpine:3.23.2
 # iproute2: busybox doesn't display ip rules properly
 RUN apk add --no-cache \
     bash \

View File

@@ -136,6 +136,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
 	client := proto.NewDaemonServiceClient(conn)
 	level := server.ParseLogLevel(args[0])
 	if level == proto.LogLevel_UNKNOWN {
+		//nolint
 		return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
 	}
@@ -313,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
 		profName = activeProf.Name
 	}
-	statusOutputString = nbstatus.ParseToFullDetailSummary(
-		nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
-	)
+	overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
+	statusOutputString = overview.FullDetailSummary()
 	}
 	return statusOutputString
 }

View File

@@ -81,6 +81,7 @@ var loginCmd = &cobra.Command{
 func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey string, activeProf *profilemanager.Profile, username string, pm *profilemanager.ProfileManager) error {
 	conn, err := DialClientGRPCServer(ctx, daemonAddr)
 	if err != nil {
+		//nolint
 		return fmt.Errorf("failed to connect to daemon error: %v\n"+
 			"If the daemon is not running please run: "+
 			"\nnetbird service install \nnetbird service start\n", err)
@@ -206,6 +207,7 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
 func switchProfile(ctx context.Context, profileName string, username string) error {
 	conn, err := DialClientGRPCServer(ctx, daemonAddr)
 	if err != nil {
+		//nolint
 		return fmt.Errorf("failed to connect to daemon error: %v\n"+
 			"If the daemon is not running please run: "+
 			"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -1,5 +1,4 @@
 //go:build pprof
-// +build pprof

 package cmd

View File

@@ -390,6 +390,7 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
 	conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
 	if err != nil {
+		//nolint
 		return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
 			"If the daemon is not running please run: "+
 			"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
 	var statusOutputString string
 	switch {
 	case detailFlag:
-		statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
+		statusOutputString = outputInformationHolder.FullDetailSummary()
 	case jsonFlag:
-		statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
+		statusOutputString, err = outputInformationHolder.JSON()
 	case yamlFlag:
-		statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
+		statusOutputString, err = outputInformationHolder.YAML()
 	default:
-		statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
+		statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
 	}

 	if err != nil {
@@ -124,6 +124,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
 func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
 	conn, err := DialClientGRPCServer(ctx, daemonAddr)
 	if err != nil {
+		//nolint
 		return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
 			"If the daemon is not running please run: "+
 			"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -89,9 +89,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
 	t.Cleanup(cleanUp)

 	eventStore := &activity.InMemoryEventStore{}
-	if err != nil {
-		return nil, nil
-	}

 	ctrl := gomock.NewController(t)
 	t.Cleanup(ctrl.Finish)

View File

@@ -216,6 +216,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
 	conn, err := DialClientGRPCServer(ctx, daemonAddr)
 	if err != nil {
+		//nolint
 		return fmt.Errorf("failed to connect to daemon error: %v\n"+
 			"If the daemon is not running please run: "+
 			"\nnetbird service install \nnetbird service start\n", err)

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh" sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
var ( var (
@@ -38,6 +39,7 @@ type Client struct {
setupKey string setupKey string
jwtToken string jwtToken string
connect *internal.ConnectClient connect *internal.ConnectClient
recorder *peer.Status
} }
// Options configures a new Client. // Options configures a new Client.
@@ -161,11 +163,17 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error { func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.cancel != nil { if c.connect != nil {
return ErrClientAlreadyStarted return ErrClientAlreadyStarted
} }
ctx := internal.CtxInitState(context.Background()) ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
@@ -173,7 +181,9 @@ func (c *Client) Start(startCtx context.Context) error {
} }
recorder := peer.NewRecorder(c.config.ManagementURL.String()) recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false) client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
@@ -197,6 +207,7 @@ func (c *Client) Start(startCtx context.Context) error {
} }
c.connect = client c.connect = client
c.cancel = cancel
return nil return nil
} }
@@ -211,17 +222,23 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted return ErrClientNotStarted
} }
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1) done := make(chan error, 1)
connect := c.connect
go func() { go func() {
done <- c.connect.Stop() done <- connect.Stop()
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.cancel = nil c.connect = nil
return ctx.Err() return ctx.Err()
case err := <-done: case err := <-done:
c.cancel = nil c.connect = nil
if err != nil { if err != nil {
return fmt.Errorf("stop: %w", err) return fmt.Errorf("stop: %w", err)
} }
@@ -315,6 +332,62 @@ func (c *Client) NewHTTPClient() *http.Client {
} }
} }
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys. // VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network, // Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures. // ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -386,11 +386,8 @@ func (m *aclManager) updateState() {

 // filterRuleSpecs returns the specs of a filtering rule
 func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
-	matchByIP := true
 	// don't use IP matching if IP is 0.0.0.0
-	if ip.IsUnspecified() {
-		matchByIP = false
-	}
+	matchByIP := !ip.IsUnspecified()

 	if matchByIP {
 		if ipsetName != "" {

View File

@@ -161,7 +161,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
 		t.Logf(" [%d] %s", i, rule)
 	}

-	var denyRuleIndex, acceptRuleIndex int = -1, -1
+	var denyRuleIndex, acceptRuleIndex = -1, -1
 	for i, rule := range rules {
 		if strings.Contains(rule, "DROP") {
 			t.Logf("Found DROP rule at index %d: %s", i, rule)

View File

@@ -198,7 +198,7 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
 	t.Logf("Found %d rules in nftables chain", len(rules))

 	// Find the accept and deny rules and verify deny comes before accept
-	var acceptRuleIndex, denyRuleIndex int = -1, -1
+	var acceptRuleIndex, denyRuleIndex = -1, -1
 	for i, rule := range rules {
 		hasAcceptHTTPSet := false
 		hasDenyHTTPSet := false
@@ -208,11 +208,13 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
 		for _, e := range rule.Exprs {
 			// Check for set lookup
 			if lookup, ok := e.(*expr.Lookup); ok {
-				if lookup.SetName == "accept-http" {
+				switch lookup.SetName {
+				case "accept-http":
 					hasAcceptHTTPSet = true
-				} else if lookup.SetName == "deny-http" {
+				case "deny-http":
 					hasDenyHTTPSet = true
 				}
 			}

 			// Check for port 80
 			if cmp, ok := e.(*expr.Cmp); ok {
@@ -222,9 +224,10 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
 			}

 			// Check for verdict
 			if verdict, ok := e.(*expr.Verdict); ok {
-				if verdict.Kind == expr.VerdictAccept {
+				switch verdict.Kind {
+				case expr.VerdictAccept:
 					action = "ACCEPT"
-				} else if verdict.Kind == expr.VerdictDrop {
+				case expr.VerdictDrop:
 					action = "DROP"
 				}
 			}


@@ -29,7 +29,7 @@ import (
 )
 const (
-	layerTypeAll = 0
+	layerTypeAll = 255
 	// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
 	ipTCPHeaderMinSize = 40
@@ -262,10 +262,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
 }
 func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
-	wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
-	if err != nil {
-		return nil, fmt.Errorf("parse wireguard network: %w", err)
-	}
+	wgPrefix := iface.Address().Network
 	log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
 	rule, err := m.addRouteFiltering(
@@ -439,19 +436,7 @@ func (m *Manager) AddPeerFiltering(
 	r.sPort = sPort
 	r.dPort = dPort
-	switch proto {
-	case firewall.ProtocolTCP:
-		r.protoLayer = layers.LayerTypeTCP
-	case firewall.ProtocolUDP:
-		r.protoLayer = layers.LayerTypeUDP
-	case firewall.ProtocolICMP:
-		r.protoLayer = layers.LayerTypeICMPv4
-		if r.ipLayer == layers.LayerTypeIPv6 {
-			r.protoLayer = layers.LayerTypeICMPv6
-		}
-	case firewall.ProtocolALL:
-		r.protoLayer = layerTypeAll
-	}
+	r.protoLayer = protoToLayer(proto, r.ipLayer)
 	m.mutex.Lock()
 	var targetMap map[netip.Addr]RuleSet
@@ -496,16 +481,17 @@ func (m *Manager) addRouteFiltering(
 	}
 	ruleID := uuid.New().String()
 	rule := RouteRule{
 		// TODO: consolidate these IDs
 		id:      ruleID,
 		mgmtId:  id,
 		sources: sources,
 		dstSet:  destination.Set,
-		proto:   proto,
+		protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
 		srcPort: sPort,
 		dstPort: dPort,
 		action:  action,
 	}
 	if destination.IsPrefix() {
 		rule.destinations = []netip.Prefix{destination.Prefix}
@@ -795,7 +781,7 @@ func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeade
 	pseudoSum += uint32(d.ip4.Protocol)
 	pseudoSum += uint32(tcpLength)
-	var sum uint32 = pseudoSum
+	var sum = pseudoSum
 	for i := 0; i < tcpLength-1; i += 2 {
 		sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
 	}
@@ -945,7 +931,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
 func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
 	ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
 	if blocked {
-		_, pnum := getProtocolFromPacket(d)
+		pnum := getProtocolFromPacket(d)
 		srcPort, dstPort := getPortsFromPacket(d)
 		m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@@ -1010,20 +996,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
 		return false
 	}
-	proto, pnum := getProtocolFromPacket(d)
+	protoLayer := d.decoded[1]
 	srcPort, dstPort := getPortsFromPacket(d)
-	ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
+	ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
 	if !pass {
+		proto := getProtocolFromPacket(d)
 		m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
-			ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
+			ruleID, proto, srcIP, srcPort, dstIP, dstPort)
 		m.flowLogger.StoreEvent(nftypes.EventFields{
 			FlowID:    uuid.New(),
 			Type:      nftypes.TypeDrop,
 			RuleID:    ruleID,
 			Direction: nftypes.Ingress,
-			Protocol:  pnum,
+			Protocol:  proto,
 			SourceIP:  srcIP,
 			DestIP:    dstIP,
 			SourcePort: srcPort,
@@ -1052,16 +1040,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
 	return true
 }
-func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
+func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
+	switch proto {
+	case firewall.ProtocolTCP:
+		return layers.LayerTypeTCP
+	case firewall.ProtocolUDP:
+		return layers.LayerTypeUDP
+	case firewall.ProtocolICMP:
+		if ipLayer == layers.LayerTypeIPv6 {
+			return layers.LayerTypeICMPv6
+		}
+		return layers.LayerTypeICMPv4
+	case firewall.ProtocolALL:
+		return layerTypeAll
+	}
+	return 0
+}
+func getProtocolFromPacket(d *decoder) nftypes.Protocol {
 	switch d.decoded[1] {
 	case layers.LayerTypeTCP:
-		return firewall.ProtocolTCP, nftypes.TCP
+		return nftypes.TCP
 	case layers.LayerTypeUDP:
-		return firewall.ProtocolUDP, nftypes.UDP
+		return nftypes.UDP
 	case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
-		return firewall.ProtocolICMP, nftypes.ICMP
+		return nftypes.ICMP
 	default:
-		return firewall.ProtocolALL, nftypes.ProtocolUnknown
+		return nftypes.ProtocolUnknown
 	}
 }
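The new `protoToLayer` helper converts a firewall protocol to a packet layer type once, at rule-creation time, so the hot matching path compares layer types directly instead of re-deriving the protocol per packet. A self-contained sketch of the same mapping, using illustrative stand-in types in place of the real `gopacket.LayerType` and `firewall.Protocol` values:

```go
package main

import "fmt"

// Stand-ins for gopacket.LayerType values; names are illustrative only.
type layerType int

const (
	layerAll layerType = iota + 1
	layerIPv4
	layerIPv6
	layerTCP
	layerUDP
	layerICMPv4
	layerICMPv6
)

type protocol string

const (
	protoTCP  protocol = "tcp"
	protoUDP  protocol = "udp"
	protoICMP protocol = "icmp"
	protoALL  protocol = "all"
)

// protoToLayer mirrors the mapping in the diff: ICMP rules resolve to
// ICMPv6 when the enclosing IP layer is IPv6, otherwise ICMPv4.
func protoToLayer(proto protocol, ipLayer layerType) layerType {
	switch proto {
	case protoTCP:
		return layerTCP
	case protoUDP:
		return layerUDP
	case protoICMP:
		if ipLayer == layerIPv6 {
			return layerICMPv6
		}
		return layerICMPv4
	case protoALL:
		return layerAll
	}
	return 0 // unknown protocol
}

func main() {
	fmt.Println(protoToLayer(protoICMP, layerIPv6) == layerICMPv6) // true
	fmt.Println(protoToLayer(protoTCP, layerIPv4) == layerTCP)     // true
}
```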
@@ -1233,19 +1238,30 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
 	}
 }
 // routeACLsPass returns true if the packet is allowed by the route ACLs
-func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
+func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
 	m.mutex.RLock()
 	defer m.mutex.RUnlock()
 	for _, rule := range m.routeRules {
-		if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
+		if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
 			return rule.mgmtId, rule.action == firewall.ActionAccept
 		}
 	}
 	return nil, false
 }
-func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
+func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
+	// TODO: handle ipv6 vs ipv4 icmp rules
+	if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
+		return false
+	}
+	if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
+		if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
+			return false
+		}
+	}
 	destMatched := false
 	for _, dst := range rule.destinations {
 		if dst.Contains(dstAddr) {
@@ -1264,21 +1280,8 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
 			break
 		}
 	}
-	if !sourceMatched {
-		return false
-	}
-	if rule.proto != firewall.ProtocolALL && rule.proto != proto {
-		return false
-	}
-	if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
-		if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
-			return false
-		}
-	}
-	return true
+	return sourceMatched
 }
 // AddUDPPacketHook calls hook when UDP packet from given direction matched


@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
 for _, tc := range cases {
 	srcIP := netip.MustParseAddr(tc.srcIP)
 	dstIP := netip.MustParseAddr(tc.dstIP)
-	manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
+	manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
 }
 }
 }


@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
 // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
 // to the forwarder
-_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
+_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
 require.Equal(t, tc.shouldPass, isAllowed)
 })
 }
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
 srcIP := netip.MustParseAddr(p.srcIP)
 dstIP := netip.MustParseAddr(p.dstIP)
-_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
+_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
 require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
 }
 })
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
 dstIP := netip.MustParseAddr("192.168.1.100")
 // Check that traffic is dropped (empty set shouldn't match anything)
-_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
+_, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
 require.False(t, isAllowed, "Empty set should not allow any traffic")
 err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
 require.NoError(t, err)
 // Now the packet should be allowed
-_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
+_, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
 require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
 }


@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
 dstIP2 := netip.MustParseAddr("192.168.1.100")
 dstIP3 := netip.MustParseAddr("172.16.0.100")
-_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
-_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
-_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
+_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
+_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
+_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
 require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
 require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
 require.NoError(t, err)
 // Check that all original prefixes are still included
-_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
-_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
+_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
+_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
 require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
 require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
 dstIP4 := netip.MustParseAddr("172.16.1.100")
 dstIP5 := netip.MustParseAddr("10.1.0.50")
-_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
-_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
+_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
+_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
 require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
 require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
 srcIP := netip.MustParseAddr("100.10.0.1")
 for _, tc := range testCases {
-	_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
+	_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
 	require.Equal(t, tc.expected, isAllowed, tc.desc)
 }
 }


@@ -2,6 +2,7 @@ package forwarder
 import (
 	"fmt"
+	"sync/atomic"
 	wgdevice "golang.zx2c4.com/wireguard/device"
 	"gvisor.dev/gvisor/pkg/tcpip"
@@ -16,7 +17,7 @@ type endpoint struct {
 	logger     *nblog.Logger
 	dispatcher stack.NetworkDispatcher
 	device     *wgdevice.Device
-	mtu        uint32
+	mtu        atomic.Uint32
 }
 func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -28,7 +29,7 @@ func (e *endpoint) IsAttached() bool {
 }
 func (e *endpoint) MTU() uint32 {
-	return e.mtu
+	return e.mtu.Load()
 }
 func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
@@ -82,6 +83,22 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
 	return true
 }
+func (e *endpoint) Close() {
+	// Endpoint cleanup - nothing to do as device is managed externally
+}
+func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
+	// Link address is not used for this endpoint type
+}
+func (e *endpoint) SetMTU(mtu uint32) {
+	e.mtu.Store(mtu)
+}
+func (e *endpoint) SetOnCloseAction(func()) {
+	// No action needed on close
+}
 type epID stack.TransportEndpointID
 func (i epID) String() string {
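The `mtu` field changes from a plain `uint32` to `atomic.Uint32` because `SetMTU` can now be called while netstack goroutines are reading via `MTU()`. A minimal standalone sketch of the same pattern (the `endpoint` type here is a simplified stand-in, not the gvisor interface):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// Simplified endpoint: atomic.Uint32 lets readers and writers race
// safely without a mutex.
type endpoint struct {
	mtu atomic.Uint32
}

func (e *endpoint) MTU() uint32     { return e.mtu.Load() }
func (e *endpoint) SetMTU(v uint32) { e.mtu.Store(v) }

func main() {
	e := &endpoint{}
	e.SetMTU(1280)

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			e.SetMTU(1420) // concurrent writers are safe
			_ = e.MTU()    // concurrent readers see a consistent value
		}()
	}
	wg.Wait()
	fmt.Println(e.MTU()) // 1420
}
```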


@@ -7,6 +7,7 @@ import (
 	"net/netip"
 	"runtime"
 	"sync"
+	"time"
 	log "github.com/sirupsen/logrus"
 	"gvisor.dev/gvisor/pkg/buffer"
@@ -35,14 +36,16 @@ type Forwarder struct {
 	logger     *nblog.Logger
 	flowLogger nftypes.FlowLogger
 	// ruleIdMap is used to store the rule ID for a given connection
 	ruleIdMap    sync.Map
 	stack        *stack.Stack
 	endpoint     *endpoint
 	udpForwarder *udpForwarder
 	ctx          context.Context
 	cancel       context.CancelFunc
 	ip           tcpip.Address
 	netstack     bool
+	hasRawICMPAccess bool
+	pingSemaphore    chan struct{}
 }
 func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -60,8 +63,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
 	endpoint := &endpoint{
 		logger: logger,
 		device: iface.GetWGDevice(),
-		mtu:    uint32(mtu),
 	}
+	endpoint.mtu.Store(uint32(mtu))
 	if err := s.CreateNIC(nicID, endpoint); err != nil {
 		return nil, fmt.Errorf("create NIC: %v", err)
@@ -103,15 +106,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
 	ctx, cancel := context.WithCancel(context.Background())
 	f := &Forwarder{
 		logger:       logger,
 		flowLogger:   flowLogger,
 		stack:        s,
 		endpoint:     endpoint,
 		udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
 		ctx:          ctx,
 		cancel:       cancel,
 		netstack:     netstack,
 		ip:           tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
+		pingSemaphore: make(chan struct{}, 3),
 	}
 	receiveWindow := defaultReceiveWindow
@@ -129,6 +133,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
 	s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
+	f.checkICMPCapability()
 	log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
 	return f, nil
 }
@@ -198,3 +204,24 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
 		DstPort: dstPort,
 	}
 }
+// checkICMPCapability tests whether we have raw ICMP socket access at startup.
+func (f *Forwarder) checkICMPCapability() {
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+	lc := net.ListenConfig{}
+	conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
+	if err != nil {
+		f.hasRawICMPAccess = false
+		f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
+		return
+	}
+	if err := conn.Close(); err != nil {
+		f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
+	}
+	f.hasRawICMPAccess = true
+	f.logger.Debug("forwarder: Raw ICMP socket access available")
+}


@@ -2,8 +2,11 @@ package forwarder
 import (
 	"context"
+	"fmt"
 	"net"
 	"net/netip"
+	"os/exec"
+	"runtime"
 	"time"
 	"github.com/google/uuid"
@@ -14,30 +17,95 @@ import (
 )
 // handleICMP handles ICMP packets from the network stack
-func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
+func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
 	icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
-	icmpType := uint8(icmpHdr.Type())
-	icmpCode := uint8(icmpHdr.Code())
-	if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
-		// dont process our own replies
-		return true
-	}
 	flowID := uuid.New()
-	f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
+	f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
-	ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
+	// For Echo Requests, send and wait for response
+	if icmpHdr.Type() == header.ICMPv4Echo {
+		return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
+	}
+	// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
+	if !f.hasRawICMPAccess {
+		f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
+		return false
+	}
+	icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
+	conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
+	if err != nil {
+		f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
+		return true
+	}
+	if err := conn.Close(); err != nil {
+		f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
+	}
+	return true
+}
+// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
+func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
+	select {
+	case f.pingSemaphore <- struct{}{}:
+		icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
+		rxBytes := pkt.Size()
+		go func() {
+			defer func() { <-f.pingSemaphore }()
+			if f.hasRawICMPAccess {
+				f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
+			} else {
+				f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
+			}
+		}()
+	default:
+		f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
+			epID(id), icmpType, icmpCode)
+	}
+	return true
+}
+// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
+// The caller is responsible for closing the returned connection.
+func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
+	ctx, cancel := context.WithTimeout(f.ctx, timeout)
 	defer cancel()
 	lc := net.ListenConfig{}
-	// TODO: support non-root
 	conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
 	if err != nil {
-		f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
+		return nil, fmt.Errorf("create ICMP socket: %w", err)
 	}
-	// This will make netstack reply on behalf of the original destination, that's ok for now
-	return false
+	dstIP := f.determineDialAddr(id.LocalAddress)
+	dst := &net.IPAddr{IP: dstIP}
+	if _, err = conn.WriteTo(payload, dst); err != nil {
+		if closeErr := conn.Close(); closeErr != nil {
+			f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
+		}
+		return nil, fmt.Errorf("write ICMP packet: %w", err)
+	}
+	f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
+		epID(id), icmpType, icmpCode)
+	return conn, nil
+}
+// handleICMPViaSocket handles ICMP echo requests using raw sockets.
+func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
+	sendTime := time.Now()
+	conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
+	if err != nil {
+		f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
+		return
+	}
 	defer func() {
 		if err := conn.Close(); err != nil {
@@ -45,38 +113,22 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
 		}
 	}()
-	dstIP := f.determineDialAddr(id.LocalAddress)
-	dst := &net.IPAddr{IP: dstIP}
-	fullPacket := stack.PayloadSince(pkt.TransportHeader())
-	payload := fullPacket.AsSlice()
-	if _, err = conn.WriteTo(payload, dst); err != nil {
-		f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
-		return true
-	}
-	f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
-		epID(id), icmpHdr.Type(), icmpHdr.Code())
-	// For Echo Requests, send and handle response
-	if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
-		rxBytes := pkt.Size()
-		txBytes := f.handleEchoResponse(icmpHdr, conn, id)
-		f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
-	}
-	// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
-	return true
+	txBytes := f.handleEchoResponse(conn, id)
+	rtt := time.Since(sendTime).Round(10 * time.Microsecond)
+	f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
+		epID(id), icmpType, icmpCode, rtt)
+	f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
 }
-func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
+func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
 	if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
 		f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
 		return 0
 	}
-	response := make([]byte, f.endpoint.mtu)
+	response := make([]byte, f.endpoint.mtu.Load())
 	n, _, err := conn.ReadFrom(response)
 	if err != nil {
 		if !isTimeout(err) {
@@ -85,31 +137,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
 		return 0
 	}
-	ipHdr := make([]byte, header.IPv4MinimumSize)
-	ip := header.IPv4(ipHdr)
-	ip.Encode(&header.IPv4Fields{
-		TotalLength: uint16(header.IPv4MinimumSize + n),
-		TTL:         64,
-		Protocol:    uint8(header.ICMPv4ProtocolNumber),
-		SrcAddr:     id.LocalAddress,
-		DstAddr:     id.RemoteAddress,
-	})
-	ip.SetChecksum(^ip.CalculateChecksum())
-	fullPacket := make([]byte, 0, len(ipHdr)+n)
-	fullPacket = append(fullPacket, ipHdr...)
-	fullPacket = append(fullPacket, response[:n]...)
-	if err := f.InjectIncomingPacket(fullPacket); err != nil {
-		f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
-		return 0
-	}
-	f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
-		epID(id), icmpHdr.Type(), icmpHdr.Code())
-	return len(fullPacket)
+	return f.injectICMPReply(id, response[:n])
 }
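`pingSemaphore` is a buffered channel used as a counting semaphore: at most `cap(sem)` echo requests are handled concurrently, and excess requests are dropped immediately via the `select` `default` branch instead of blocking the packet path. A minimal sketch of the same pattern:

```go
package main

import "fmt"

// tryAcquire takes a semaphore slot if one is free; otherwise it
// returns false immediately (drop rather than block).
func tryAcquire(sem chan struct{}) bool {
	select {
	case sem <- struct{}{}:
		return true
	default:
		return false // over the limit
	}
}

func release(sem chan struct{}) { <-sem }

func main() {
	sem := make(chan struct{}, 3) // mirrors make(chan struct{}, 3) in New()
	accepted := 0
	for i := 0; i < 5; i++ {
		if tryAcquire(sem) {
			accepted++
		}
	}
	fmt.Println(accepted) // 3: the other 2 requests were rate limited
	for i := 0; i < accepted; i++ {
		release(sem)
	}
}
```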
// sendICMPEvent stores flow events for ICMP packets // sendICMPEvent stores flow events for ICMP packets
@@ -152,3 +180,95 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
 	f.flowLogger.StoreEvent(fields)
 }
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
// This is used as a fallback when raw socket access is not available.
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
dstIP := f.determineDialAddr(id.LocalAddress)
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
pingStart := time.Now()
if err := cmd.Run(); err != nil {
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
icmpType, icmpCode, err)
return
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// buildPingCommand creates a platform-specific ping command.
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
timeoutSec := int(timeout.Seconds())
if timeoutSec < 1 {
timeoutSec = 1
}
switch runtime.GOOS {
case "linux", "android":
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "darwin", "ios":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
case "freebsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
case "openbsd", "netbsd":
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
case "windows":
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
default:
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
}
}
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
// Returns the size of the injected packet.
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
replyICMP := make([]byte, len(icmpData))
copy(replyICMP, icmpData)
replyICMPHdr := header.ICMPv4(replyICMP)
replyICMPHdr.SetType(header.ICMPv4EchoReply)
replyICMPHdr.SetChecksum(0)
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
return f.injectICMPReply(id, replyICMP)
}
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
// Returns the total size of the injected packet, or 0 if injection failed.
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, icmpPayload...)
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
return 0
}
return len(fullPacket)
}
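Both `header.ICMPv4Checksum` and `ip.CalculateChecksum` above compute the standard Internet checksum (RFC 1071): sum the data as big-endian 16-bit words, fold any carries back into the low 16 bits, and take the one's complement. A compact reimplementation for illustration:

```go
package main

import "fmt"

// internetChecksum computes the RFC 1071 Internet checksum over b.
func internetChecksum(b []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(b[i])<<8 | uint32(b[i+1]) // big-endian 16-bit words
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8 // odd trailing byte is zero-padded
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + sum>>16 // fold carries back in
	}
	return ^uint16(sum) // one's complement
}

func main() {
	// Classic IPv4 header example with its checksum field zeroed.
	hdr := []byte{
		0x45, 0x00, 0x00, 0x73, 0x00, 0x00, 0x40, 0x00,
		0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x00, 0x01,
		0xc0, 0xa8, 0x00, 0xc7,
	}
	fmt.Printf("0x%04x\n", internetChecksum(hdr)) // 0xb861
}
```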


@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"net"
 	"net/netip"
 	"sync"

@@ -131,10 +132,10 @@ func (f *udpForwarder) cleanup() {
 	}
 }

 // handleUDP is called by the UDP forwarder for new packets
-func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
+func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
 	if f.ctx.Err() != nil {
 		f.logger.Trace("forwarder: context done, dropping UDP packet")
-		return
+		return false
 	}

 	id := r.ID()

@@ -144,7 +145,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 	f.udpForwarder.RUnlock()
 	if exists {
 		f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
-		return
+		return true
 	}

 	flowID := uuid.New()

@@ -162,7 +163,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 	if err != nil {
 		f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
 		// TODO: Send ICMP error message
-		return
+		return false
 	}

 	// Create wait queue for blocking syscalls

@@ -173,10 +174,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 		if err := outConn.Close(); err != nil {
 			f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
 		}
-		return
+		return false
 	}

-	inConn := gonet.NewUDPConn(f.stack, &wq, ep)
+	inConn := gonet.NewUDPConn(&wq, ep)

 	connCtx, connCancel := context.WithCancel(f.ctx)
 	pConn := &udpPacketConn{

@@ -199,7 +200,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 		if err := outConn.Close(); err != nil {
 			f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
 		}
-		return
+		return true
 	}
 	f.udpForwarder.conns[id] = pConn
 	f.udpForwarder.Unlock()

@@ -208,6 +209,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
 	f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
 	go f.proxyUDP(connCtx, pConn, id, ep)
+	return true
 }

 func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {

@@ -348,7 +350,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
 }

 func isClosedError(err error) bool {
-	return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
+	return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
 }

 func isTimeout(err error) bool {

View File

@@ -130,6 +130,7 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
 	// 127.0.0.0/8
 	newIPv4Bitmap[127] = &ipv4LowBitmap{}
 	for i := 0; i < 8192; i++ {
+		// #nosec G602 -- bitmap is defined as [8192]uint32, loop range is correct
 		newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
 	}
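For context on the loop being annotated: the bitmap behind it is a fixed `[8192]uint32` array, and writing `0xFFFFFFFF` into every word marks the entire 127.0.0.0/8 range as local in one pass. A minimal sketch of the same bit-per-entry idea (the array size and method names here are illustrative, not the real `ipv4LowBitmap` API):

```go
package main

import "fmt"

// bitmap is a tiny fixed-size uint32 bitmap; entry i lives in word i/32 at
// bit position i%32. The real type in the diff uses [8192]uint32.
type bitmap [8]uint32

func (b *bitmap) set(i uint32)      { b[i/32] |= 1 << (i % 32) }
func (b *bitmap) has(i uint32) bool { return b[i/32]&(1<<(i%32)) != 0 }

func main() {
	var b bitmap
	// Setting every word to all-ones marks the whole range, mirroring the
	// 127.0.0.0/8 loop in the diff.
	for i := range b {
		b[i] = 0xFFFFFFFF
	}
	fmt.Println(b.has(0), b.has(255))
}
```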

View File

@@ -218,7 +218,7 @@ func BenchmarkIPChecks(b *testing.B) {
 		b.ResetTimer()
 		for i := 0; i < b.N; i++ {
 			// nolint:gosimple
-			_, _ = mapManager.localIPs[ip.String()]
+			_ = mapManager.localIPs[ip.String()]
 		}
 	})

@@ -227,7 +227,7 @@ func BenchmarkIPChecks(b *testing.B) {
 		b.ResetTimer()
 		for i := 0; i < b.N; i++ {
 			// nolint:gosimple
-			_, _ = mapManager.localIPs[ip.String()]
+			_ = mapManager.localIPs[ip.String()]
 		}
 	})
 }

View File

@@ -168,6 +168,15 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
 	}
 }

+func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
+	if l.level.Load() >= uint32(LevelWarn) {
+		select {
+		case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
+		default:
+		}
+	}
+}
+
 func (l *Logger) Debug1(format string, arg1 any) {
 	if l.level.Load() >= uint32(LevelDebug) {
 		select {

View File

@@ -234,9 +234,10 @@ func TestInboundPortDNATNegative(t *testing.T) {
 			require.False(t, translated, "Packet should NOT be translated for %s", tc.name)

 			d = parsePacket(t, packet)
-			if tc.protocol == layers.IPProtocolTCP {
+			switch tc.protocol {
+			case layers.IPProtocolTCP:
 				require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
-			} else if tc.protocol == layers.IPProtocolUDP {
+			case layers.IPProtocolUDP:
 				require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
 			}
 		})

View File

@@ -34,7 +34,7 @@ type RouteRule struct {
 	sources      []netip.Prefix
 	dstSet       firewall.Set
 	destinations []netip.Prefix
-	proto        firewall.Protocol
+	protoLayer   gopacket.LayerType
 	srcPort      *firewall.Port
 	dstPort      *firewall.Port
 	action       firewall.Action

View File

@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
 }

 func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
-	proto, _ := getProtocolFromPacket(d)
+	protoLayer := d.decoded[1]
 	srcPort, dstPort := getPortsFromPacket(d)

-	id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
+	id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
 	strId := string(id)
 	if id == nil {

View File

@@ -27,8 +27,23 @@ type receiverCreator struct {
 	iceBind *ICEBind
 }

-func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
-	return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
+func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
+	if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
+		return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
+	}
+
+	// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
+	// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
+	return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
+		buf := bufs[0]
+		size, ep, err := conn.ReadFromUDPAddrPort(buf)
+		if err != nil {
+			return 0, err
+		}
+		sizes[0] = size
+		stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
+		eps[0] = stdEp
+		return 1, nil
+	}
 }

 // ICEBind is a bind implementation with two main features:

View File

@@ -1,6 +1,3 @@
-//go:build ios
-// +build ios
-
 package device

 import (

View File

@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
 	return syncResponse, nil
 }

+// SetLogLevel sets the log level for the firewall manager if the engine is running.
+func (c *ConnectClient) SetLogLevel(level log.Level) {
+	engine := c.Engine()
+	if engine == nil {
+		return
+	}
+
+	fwManager := engine.GetFirewallManager()
+	if fwManager != nil {
+		fwManager.SetLogLevel(level)
+	}
+}
+
 // Status returns the current client status
 func (c *ConnectClient) Status() StatusType {
 	if c == nil {

View File

@@ -507,15 +507,13 @@ func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string {
 	if p.Base == expr.PayloadBaseNetworkHeader {
 		switch p.Offset {
 		case 12:
-			if p.Len == 4 {
-				return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
-			} else if p.Len == 2 {
+			switch p.Len {
+			case 4, 2:
 				return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
 			}
 		case 16:
-			if p.Len == 4 {
-				return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
-			} else if p.Len == 2 {
+			switch p.Len {
+			case 4, 2:
 				return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data))
 			}
 		}

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
 	var records []nbdns.SimpleRecord
 	for _, zone := range config.CustomZones {
-		if zone.SkipPTRProcess {
+		if zone.NonAuthoritative {
 			continue
 		}
 		for _, record := range zone.Records {

View File

@@ -3,17 +3,21 @@ package dns
 import (
 	"fmt"
 	"slices"
+	"strconv"
 	"strings"
 	"sync"
+	"time"

 	"github.com/miekg/dns"
 	log "github.com/sirupsen/logrus"
+
+	"github.com/netbirdio/netbird/client/internal/dns/resutil"
 )

 const (
 	PriorityMgmtCache = 150
-	PriorityLocal     = 100
-	PriorityDNSRoute  = 75
+	PriorityDNSRoute  = 100
+	PriorityLocal     = 75
 	PriorityUpstream  = 50
 	PriorityDefault   = 1
 	PriorityFallback  = -100

@@ -43,7 +47,23 @@ type HandlerChain struct {
 type ResponseWriterChain struct {
 	dns.ResponseWriter
 	origPattern    string
+	requestID      string
 	shouldContinue bool
+	response       *dns.Msg
+	meta           map[string]string
+}
+
+// RequestID returns the request ID for tracing
+func (w *ResponseWriterChain) RequestID() string {
+	return w.requestID
+}
+
+// SetMeta sets a metadata key-value pair for logging
+func (w *ResponseWriterChain) SetMeta(key, value string) {
+	if w.meta == nil {
+		w.meta = make(map[string]string)
+	}
+	w.meta[key] = value
 }

 func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {

@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
 		w.shouldContinue = true
 		return nil
 	}
+	w.response = m
 	return w.ResponseWriter.WriteMsg(m)
 }

@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
 	pos := c.findHandlerPosition(entry)
 	c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
+
+	c.logHandlers()
 }

 // findHandlerPosition determines where to insert a new handler based on priority and specificity

@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
 	for i := len(c.handlers) - 1; i >= 0; i-- {
 		entry := c.handlers[i]
 		if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
-			log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
 			c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
+			c.logHandlers()
 			break
 		}
 	}
 }

+// logHandlers logs the current handler chain state. Caller must hold the lock.
+func (c *HandlerChain) logHandlers() {
+	if !log.IsLevelEnabled(log.TraceLevel) {
+		return
+	}
+
+	var b strings.Builder
+	b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
+	for _, h := range c.handlers {
+		b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
+			" wildcard=" + strconv.FormatBool(h.IsWildcard) +
+			" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
+			" priority=" + strconv.Itoa(h.Priority) + "\n")
+	}
+	log.Trace(strings.TrimSuffix(b.String(), "\n"))
+}
+
 func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 	if len(r.Question) == 0 {
 		return
 	}

-	qname := strings.ToLower(r.Question[0].Name)
+	startTime := time.Now()
+	requestID := resutil.GenerateRequestID()
+	logger := log.WithFields(log.Fields{
+		"request_id": requestID,
+		"dns_id":     fmt.Sprintf("%04x", r.Id),
+	})
+
+	question := r.Question[0]
+	qname := strings.ToLower(question.Name)

 	c.mu.RLock()
 	handlers := slices.Clone(c.handlers)
 	c.mu.RUnlock()

-	if log.IsLevelEnabled(log.TraceLevel) {
-		var b strings.Builder
-		b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
-		for _, h := range handlers {
-			b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
-				h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
-		}
-		log.Trace(strings.TrimSuffix(b.String(), "\n"))
-	}
-
 	// Try handlers in priority order
 	for _, entry := range handlers {
-		matched := c.isHandlerMatch(qname, entry)
-		if matched {
-			log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
-				qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
-			chainWriter := &ResponseWriterChain{
-				ResponseWriter: w,
-				origPattern:    entry.OrigPattern,
-			}
-			entry.Handler.ServeDNS(chainWriter, r)
-
-			// If handler wants to continue, try next handler
-			if chainWriter.shouldContinue {
-				// Only log continue for non-management cache handlers to reduce noise
-				if entry.Priority != PriorityMgmtCache {
-					log.Tracef("handler requested continue to next handler for domain=%s", qname)
-				}
-				continue
-			}
-			return
+		if !c.isHandlerMatch(qname, entry) {
+			continue
 		}
+
+		handlerName := entry.OrigPattern
+		if s, ok := entry.Handler.(interface{ String() string }); ok {
+			handlerName = s.String()
+		}
+		logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
+			qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
+			handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
+
+		chainWriter := &ResponseWriterChain{
+			ResponseWriter: w,
+			origPattern:    entry.OrigPattern,
+			requestID:      requestID,
+		}
+		entry.Handler.ServeDNS(chainWriter, r)
+
+		// If handler wants to continue, try next handler
+		if chainWriter.shouldContinue {
+			if entry.Priority != PriorityMgmtCache {
+				logger.Tracef("handler requested continue for domain=%s", qname)
+			}
+			continue
+		}
+
+		c.logResponse(logger, chainWriter, qname, startTime)
+		return
 	}

 	// No handler matched or all handlers passed
-	log.Tracef("no handler found for domain=%s", qname)
+	logger.Tracef("no handler found for domain=%s type=%s class=%s",
+		qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
 	resp := &dns.Msg{}
 	resp.SetRcode(r, dns.RcodeRefused)
 	if err := w.WriteMsg(resp); err != nil {
-		log.Errorf("failed to write DNS response: %v", err)
+		logger.Errorf("failed to write DNS response: %v", err)
 	}
 }

+func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
+	if cw.response == nil {
+		return
+	}
+
+	var meta string
+	for k, v := range cw.meta {
+		meta += " " + k + "=" + v
+	}
+	logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
+		qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
+		meta, time.Since(startTime))
+}
+
 func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
 	switch {
 	case entry.Pattern == ".":
View File

@@ -1,30 +1,52 @@
 package local

 import (
+	"context"
+	"errors"
 	"fmt"
+	"net"
+	"net/netip"
 	"slices"
 	"strings"
 	"sync"
+	"time"

 	"github.com/miekg/dns"
 	log "github.com/sirupsen/logrus"
 	"golang.org/x/exp/maps"

+	"github.com/netbirdio/netbird/client/internal/dns/resutil"
 	"github.com/netbirdio/netbird/client/internal/dns/types"
 	nbdns "github.com/netbirdio/netbird/dns"
 	"github.com/netbirdio/netbird/shared/management/domain"
 )

+const externalResolutionTimeout = 4 * time.Second
+
+type resolver interface {
+	LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
+}
+
 type Resolver struct {
 	mu      sync.RWMutex
 	records map[dns.Question][]dns.RR
 	domains map[domain.Domain]struct{}
+	// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
+	zones map[domain.Domain]bool
+
+	resolver resolver
+	ctx      context.Context
+	cancel   context.CancelFunc
 }

 func NewResolver() *Resolver {
+	ctx, cancel := context.WithCancel(context.Background())
 	return &Resolver{
 		records: make(map[dns.Question][]dns.RR),
 		domains: make(map[domain.Domain]struct{}),
+		zones:   make(map[domain.Domain]bool),
+		ctx:     ctx,
+		cancel:  cancel,
 	}
 }

@@ -37,7 +59,18 @@ func (d *Resolver) String() string {
 	return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
 }

-func (d *Resolver) Stop() {}
+func (d *Resolver) Stop() {
+	if d.cancel != nil {
+		d.cancel()
+	}
+
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	maps.Clear(d.records)
+	maps.Clear(d.domains)
+	maps.Clear(d.zones)
+}

 // ID returns the unique handler ID
 func (d *Resolver) ID() types.HandlerID {

@@ -48,35 +81,85 @@ func (d *Resolver) ProbeAvailability() {}

 // ServeDNS handles a DNS request
 func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	logger := log.WithField("request_id", resutil.GetRequestID(w))
 	if len(r.Question) == 0 {
-		log.Debugf("received local resolver request with no question")
+		logger.Debug("received local resolver request with no question")
 		return
 	}
 	question := r.Question[0]
 	question.Name = strings.ToLower(dns.Fqdn(question.Name))

-	log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
-
 	replyMessage := &dns.Msg{}
 	replyMessage.SetReply(r)
 	replyMessage.RecursionAvailable = true

-	// lookup all records matching the question
-	records := d.lookupRecords(question)
-	if len(records) > 0 {
-		replyMessage.Rcode = dns.RcodeSuccess
-		replyMessage.Answer = append(replyMessage.Answer, records...)
-	} else {
-		// Check if we have any records for this domain name with different types
-		if d.hasRecordsForDomain(domain.Domain(question.Name)) {
-			replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
-		} else {
-			replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
-		}
-	}
+	result := d.lookupRecords(logger, question)
+	replyMessage.Authoritative = !result.hasExternalData
+	replyMessage.Answer = result.records
+	replyMessage.Rcode = d.determineRcode(question, result)
+
+	if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
+		d.continueToNext(logger, w, r)
+		return
+	}

 	if err := w.WriteMsg(replyMessage); err != nil {
-		log.Warnf("failed to write the local resolver response: %v", err)
+		logger.Warnf("failed to write the local resolver response: %v", err)
+	}
+}
+
+// determineRcode returns the appropriate DNS response code.
+// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
+// even if CNAME records are included in the answer.
+func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
+	// Use the rcode from lookup - this properly handles CNAME chains where
+	// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
+	if result.rcode != 0 {
+		return result.rcode
+	}
+
+	// No records found, but domain exists with different record types (NODATA)
+	if d.hasRecordsForDomain(domain.Domain(question.Name)) {
+		return dns.RcodeSuccess
+	}
+	return dns.RcodeNameError
+}
+
+// findZone finds the matching zone for a query name using reverse suffix lookup.
+// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
+func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
+	qname = strings.ToLower(dns.Fqdn(qname))
+	for {
+		if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
+			return nonAuth, true
+		}
+		// Move to parent domain
+		idx := strings.Index(qname, ".")
+		if idx == -1 || idx == len(qname)-1 {
+			return false, false
+		}
+		qname = qname[idx+1:]
+	}
+}
+
+// shouldFallthrough checks if the query should fallthrough to the next handler.
+// Returns true if the queried name belongs to a non-authoritative zone.
+func (d *Resolver) shouldFallthrough(qname string) bool {
+	d.mu.RLock()
+	defer d.mu.RUnlock()
+	nonAuth, found := d.findZone(qname)
+	return found && nonAuth
+}
+
+func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
+	resp := &dns.Msg{}
+	resp.SetRcode(r, dns.RcodeNameError)
+	resp.MsgHdr.Zero = true
+	if err := w.WriteMsg(resp); err != nil {
+		logger.Warnf("failed to write continue signal: %v", err)
+	}
 }

@@ -89,8 +172,27 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
 	return exists
 }

+// isInManagedZone checks if the given name falls within any of our managed zones.
+// This is used to avoid unnecessary external resolution for CNAME targets that
+// are within zones we manage - if we don't have a record for it, it doesn't exist.
+// Caller must NOT hold the lock.
+func (d *Resolver) isInManagedZone(name string) bool {
+	d.mu.RLock()
+	defer d.mu.RUnlock()
+	_, found := d.findZone(name)
+	return found
+}
+
+// lookupResult contains the result of a DNS lookup operation.
+type lookupResult struct {
+	records         []dns.RR
+	rcode           int
+	hasExternalData bool
+}
+
 // lookupRecords fetches *all* DNS records matching the first question in r.
-func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
+func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
 	d.mu.RLock()
 	records, found := d.records[question]

@@ -98,10 +200,14 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
 		d.mu.RUnlock()

 		// alternatively check if we have a cname
 		if question.Qtype != dns.TypeCNAME {
-			question.Qtype = dns.TypeCNAME
-			return d.lookupRecords(question)
+			cnameQuestion := dns.Question{
+				Name:   question.Name,
+				Qtype:  dns.TypeCNAME,
+				Qclass: question.Qclass,
+			}
+			return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
 		}
-		return nil
+		return lookupResult{rcode: dns.RcodeNameError}
 	}

 	recordsCopy := slices.Clone(records)

@@ -119,20 +225,178 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
 		d.mu.Unlock()
 	}

-	return recordsCopy
+	return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
 }

-func (d *Resolver) Update(update []nbdns.SimpleRecord) {
+// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
+// the final resolved record of the requested type. This is required for musl libc
+// compatibility, which expects the full answer chain rather than just the CNAME.
+func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
+	const maxDepth = 8
+	var chain []dns.RR
+
+	for range maxDepth {
+		cnameRecords := d.getRecords(cnameQuestion)
+		if len(cnameRecords) == 0 {
+			break
+		}
+		chain = append(chain, cnameRecords...)
+
+		cname, ok := cnameRecords[0].(*dns.CNAME)
+		if !ok {
+			break
+		}
+		targetName := strings.ToLower(cname.Target)
+
+		targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
+		// keep following chain
+		if targetResult.rcode == -1 {
+			cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
+			continue
+		}
+		return d.buildChainResult(chain, targetResult)
+	}
+
+	if len(chain) > 0 {
+		return lookupResult{records: chain, rcode: dns.RcodeSuccess}
+	}
+	return lookupResult{rcode: dns.RcodeSuccess}
+}
+
+// buildChainResult combines CNAME chain records with the target resolution result.
+// Per RFC 6604, the final rcode is propagated through the chain.
+func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
+	records := chain
+	if len(target.records) > 0 {
+		records = append(records, target.records...)
+	}
+
+	// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
+	if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
+		return lookupResult{
+			records:         records,
+			rcode:           dns.RcodeServerFailure,
+			hasExternalData: true,
+		}
+	}
+
+	return lookupResult{
+		records:         records,
+		rcode:           target.rcode,
+		hasExternalData: target.hasExternalData,
+	}
+}
+
+// resolveCNAMETarget attempts to resolve a CNAME target name.
+// Returns rcode=-1 to signal "keep following the chain".
+func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
+	if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
+		return lookupResult{records: records, rcode: dns.RcodeSuccess}
+	}
+
+	// another CNAME, keep following
+	if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
+		return lookupResult{rcode: -1}
+	}
+
+	// domain exists locally but not this record type (NODATA)
+	if d.hasRecordsForDomain(domain.Domain(targetName)) {
+		return lookupResult{rcode: dns.RcodeSuccess}
+	}
+
+	// in our zone but doesn't exist (NXDOMAIN)
+	if d.isInManagedZone(targetName) {
+		return lookupResult{rcode: dns.RcodeNameError}
+	}
+
+	return d.resolveExternal(logger, targetName, targetType)
+}
+
+func (d *Resolver) getRecords(q dns.Question) []dns.RR {
+	d.mu.RLock()
+	defer d.mu.RUnlock()
+	return d.records[q]
+}
+
+func (d *Resolver) hasRecord(q dns.Question) bool {
+	d.mu.RLock()
+	defer d.mu.RUnlock()
+	_, ok := d.records[q]
+	return ok
+}
+
+// resolveExternal resolves a domain name using the system resolver.
+// This is used to resolve CNAME targets that point outside our local zone,
+// which is required for musl libc compatibility (musl expects complete answers).
+func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
+	network := resutil.NetworkForQtype(qtype)
+	if network == "" {
+		return lookupResult{rcode: dns.RcodeNotImplemented}
+	}
+
+	resolver := d.resolver
+	if resolver == nil {
+		resolver = net.DefaultResolver
+	}
+
+	ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
+	defer cancel()
+
+	result := resutil.LookupIP(ctx, resolver, network, name, qtype)
+	if result.Err != nil {
+		d.logDNSError(logger, name, qtype, result.Err)
+		return lookupResult{rcode: result.Rcode, hasExternalData: true}
+	}
+
+	return lookupResult{
+		records:         resutil.IPsToRRs(name, result.IPs, 60),
+		rcode:           dns.RcodeSuccess,
+		hasExternalData: true,
+	}
+}
+
+// logDNSError logs DNS resolution errors for debugging.
+func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
+	qtypeName := dns.TypeToString[qtype]
+
+	var dnsErr *net.DNSError
+	if !errors.As(err, &dnsErr) {
+		logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
+		return
+	}
+
+	if dnsErr.IsNotFound {
+		logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
+		return
+	}
+
+	if dnsErr.Server != "" {
+		logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
+	} else {
+		logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
+	}
+}
+
+// Update replaces all zones and their records
+func (d *Resolver) Update(customZones []nbdns.CustomZone) {
 	d.mu.Lock()
 	defer d.mu.Unlock()

 	maps.Clear(d.records)
 	maps.Clear(d.domains)
+	maps.Clear(d.zones)

-	for _, rec := range update {
-		if err := d.registerRecord(rec); err != nil {
-			log.Warnf("failed to register the record (%s): %v", rec, err)
-			continue
+	for _, zone := range customZones {
+		zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
+		d.zones[zoneDomain] = zone.NonAuthoritative
+
+		for _, rec := range zone.Records {
+			if err := d.registerRecord(rec); err != nil {
+				log.Warnf("failed to register the record (%s): %v", rec, err)
+			}
 		}
 	}
 }
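The new `findZone` walks the query name one label at a time toward the root, so zone matching costs O(labels in the name) rather than a scan over every zone. A standalone sketch of the same reverse suffix lookup, using a plain `map[string]bool` in place of the `domain.Domain` map (zone names, as in `Update`, are assumed to be stored as lowercase FQDNs):

```go
package main

import (
	"fmt"
	"strings"
)

// findZone strips the leftmost label on each iteration and probes the zone
// map, returning the NonAuthoritative flag of the first (longest) match.
func findZone(zones map[string]bool, qname string) (nonAuthoritative, found bool) {
	qname = strings.ToLower(qname)
	if !strings.HasSuffix(qname, ".") {
		qname += "." // normalize to an FQDN, like dns.Fqdn does
	}
	for {
		if nonAuth, ok := zones[qname]; ok {
			return nonAuth, true
		}
		idx := strings.Index(qname, ".")
		if idx == -1 || idx == len(qname)-1 {
			return false, false // walked past the last label: no zone matches
		}
		qname = qname[idx+1:] // move to the parent domain
	}
}

func main() {
	zones := map[string]bool{"netbird.cloud.": false, "corp.example.com.": true}
	fmt.Println(findZone(zones, "peer1.netbird.cloud"))  // false true
	fmt.Println(findZone(zones, "db.corp.example.com.")) // true true
	fmt.Println(findZone(zones, "other.org."))           // false false
}
```

This is also what makes the NXDOMAIN fallthrough cheap: `shouldFallthrough` only pays for one suffix walk per refused query.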

View File

@@ -1,8 +1,14 @@
package local package local
import ( import (
"context"
"fmt"
"net"
"net/netip"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -12,6 +18,18 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
// mockResolver implements resolver for testing
type mockResolver struct {
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
}
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
if m.lookupFunc != nil {
return m.lookupFunc(ctx, network, host)
}
return nil, nil
}
func TestLocalResolver_ServeDNS(t *testing.T) { func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{ recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.", Name: "peera.netbird.cloud.",
@@ -106,11 +124,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
 	resolver := NewResolver()
 
-	update1 := []nbdns.SimpleRecord{record1}
-	update2 := []nbdns.SimpleRecord{record2}
+	zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
+	zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
 
 	// Apply first update
-	resolver.Update(update1)
+	resolver.Update(zone1)
 
 	// Verify first update
 	resolver.mu.RLock()
@@ -122,7 +140,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
 	assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
 
 	// Apply second update
-	resolver.Update(update2)
+	resolver.Update(zone2)
 
 	// Verify second update
 	resolver.mu.RLock()
@@ -151,10 +169,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
 		Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
 	}
 
-	update := []nbdns.SimpleRecord{record1, record2}
+	zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
 
 	// Apply update with both records
-	resolver.Update(update)
+	resolver.Update(zones)
 
 	// Create question that matches both records
 	question := dns.Question{
@@ -195,10 +213,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
 		Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
 	}
 
-	update := []nbdns.SimpleRecord{record1, record2, record3}
+	zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
 
 	// Apply update with all three records
-	resolver.Update(update)
+	resolver.Update(zones)
 
 	msg := new(dns.Msg).SetQuestion(recordName, recordType)
@@ -264,7 +282,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
 	}
 
 	// Update resolver with the records
-	resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
+	resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
 
 	testCases := []struct {
 		name string
@@ -379,7 +397,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
 	}
 
 	// Update resolver with both records
-	resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
+	resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
 
 	testCases := []struct {
 		name string
@@ -476,6 +494,20 @@
 // with 0 records instead of NXDOMAIN
 func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
 	resolver := NewResolver()
+	// Mock external resolver for CNAME target resolution
+	resolver.resolver = &mockResolver{
+		lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
+			if host == "target.example.com." {
+				if network == "ip4" {
+					return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
+				}
+				if network == "ip6" {
+					return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
+				}
+			}
+			return nil, &net.DNSError{IsNotFound: true, Name: host}
+		},
+	}
 
 	recordA := nbdns.SimpleRecord{
 		Name: "example.netbird.cloud.",
@@ -493,7 +525,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
 		RData: "target.example.com.",
 	}
 
-	resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
+	resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
 
 	testCases := []struct {
 		name string
@@ -582,3 +614,808 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
 		})
 	}
 }
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("simple internal CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "target.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "192.168.1.1", a.A.String())
})
t.Run("multi-hop CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3)
})
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok)
})
}
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
t.Run("chain at max depth resolves", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 7 CNAMEs (under max of 8)
for i := 1; i <= 7; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("hop%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("hop%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 8)
})
t.Run("chain exceeding max depth stops", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 10 CNAMEs (exceeds max of 8)
for i := 1; i <= 10; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("deep%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("deep%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
// Should NOT have the final A record (chain too deep)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
},
}})
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
}
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "93.184.216.34", a.A.String())
})
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
aaaa, ok := resp.Answer[1].(*dns.AAAA)
require.True(t, ok)
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
})
t.Run("concurrent external resolution", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
var wg sync.WaitGroup
results := make([]*dns.Msg, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
results[idx] = resp
}(i)
}
wg.Wait()
for i, resp := range results {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
}
})
}
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("Update sets zones correctly", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}},
{Domain: "test.local."},
})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("other.example.com."))
assert.True(t, resolver.isInManagedZone("sub.test.local."))
assert.False(t, resolver.isInManagedZone("external.com."))
})
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
})
t.Run("Update clears zones", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
resolver.Update(nil)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
}
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
})
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.other.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
})
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
resolver := NewResolver()
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." {
if network == "ip6" {
// No AAAA records
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
if network == "ip4" {
// But A records exist - domain exists
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
// Table-driven test for all external resolution outcomes
externalCases := []struct {
name string
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
expectedRcode int
expectedAnswer int
}{
{
name: "external NXDOMAIN (both A and AAAA not found)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeNameError,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (temporary error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTemporary: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (timeout)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTimeout: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (generic error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, fmt.Errorf("connection refused")
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external success with IPs",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeSuccess,
expectedAnswer: 2, // CNAME + A
},
}
for _, tc := range externalCases {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
if tc.expectedAnswer > 0 {
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "first answer should be CNAME")
}
})
}
}
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
// trigger fallthrough (Zero bit set) when no records match
func TestLocalResolver_Fallthrough(t *testing.T) {
resolver := NewResolver()
record := nbdns.SimpleRecord{
Name: "existing.custom.zone.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}
testCases := []struct {
name string
zones []nbdns.CustomZone
queryName string
expectFallthrough bool
expectRecord bool
}{
{
name: "Authoritative zone returns NXDOMAIN without fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: false,
expectRecord: false,
},
{
name: "Non-authoritative zone triggers fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: true,
expectRecord: false,
},
{
name: "Record found in non-authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
{
name: "Record found in authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resolver.Update(tc.zones)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response")
if tc.expectFallthrough {
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
} else {
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
}
if tc.expectRecord {
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
}
})
}
}
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
t.Run("direct record lookup is authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.True(t, resp.Authoritative)
})
t.Run("external resolution is not authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2)
assert.False(t, resp.Authoritative)
})
}
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Len(t, resp.Answer, 0)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
resolver.Stop()
resolver.Stop()
})
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})
lookupCtxCanceled := make(chan struct{})
resolver.resolver = &mockResolver{
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
close(lookupStarted)
<-ctx.Done()
close(lookupCtxCanceled)
return nil, ctx.Err()
},
}
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
done := make(chan struct{})
go func() {
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
close(done)
}()
<-lookupStarted
resolver.Stop()
select {
case <-lookupCtxCanceled:
case <-time.After(time.Second):
t.Fatal("external lookup context was not canceled")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("ServeDNS did not return after Stop")
}
})
}
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "EXAMPLE.COM.",
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
NonAuthoritative: true,
}})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG)
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
}
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
func BenchmarkFindZone_BestCase(b *testing.B) {
resolver := NewResolver()
// Single zone that matches immediately
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
NonAuthoritative: true,
}})
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough("example.com.")
}
}
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
func BenchmarkFindZone_WorstCase(b *testing.B) {
resolver := NewResolver()
// 100 zones that won't match
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
NonAuthoritative: true,
})
}
resolver.Update(zones)
// Query with many labels that won't match any zone
qname := "a.b.c.d.e.f.g.h.external.com."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
func BenchmarkFindZone_TypicalCase(b *testing.B) {
resolver := NewResolver()
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
resolver.Update([]nbdns.CustomZone{
{Domain: "netbird.cloud.", NonAuthoritative: false},
{Domain: "custom.local.", NonAuthoritative: true},
})
// Query for subdomain of user zone
qname := "myhost.custom.local."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver := NewResolver()
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
})
}
resolver.Update(zones)
// Query that matches zone50
qname := "host.zone50.internal."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.isInManagedZone(qname)
}
}
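The CNAME chain behavior these tests exercise (follow internal CNAMEs hop by hop, cap the chain at 8 entries, and rely on the cap alone to break loops) can be sketched in isolation. This is a minimal stand-alone sketch, not the resolver's actual implementation: `record`, `demoZone`, and `resolveChain` are hypothetical names, and the real resolver stores `dns.RR` values rather than strings.

```go
package main

import "fmt"

// record is a minimal, hypothetical stand-in for a zone entry: either a CNAME
// pointing at another name or a terminal A record holding an address.
type record struct {
	cnameTarget string // non-empty for CNAME records
	addr        string // non-empty for A records
}

// maxCNAMEDepth mirrors the limit the tests probe with 7-hop and 10-hop chains.
const maxCNAMEDepth = 8

// demoZone holds a simple chain and a circular pair, matching the test setups.
var demoZone = map[string]record{
	"alias.example.com.":  {cnameTarget: "target.example.com."},
	"target.example.com.": {addr: "192.168.1.1"},
	"loop1.test.":         {cnameTarget: "loop2.test."},
	"loop2.test.":         {cnameTarget: "loop1.test."},
}

// resolveChain follows CNAMEs starting at name, collecting answers until it
// reaches an A record, a name missing from the zone, or the depth cap. The cap
// also bounds circular chains (loop1 -> loop2 -> loop1) without explicit
// cycle detection.
func resolveChain(zone map[string]record, name string) []string {
	var answers []string
	for depth := 0; depth < maxCNAMEDepth; depth++ {
		rec, ok := zone[name]
		if !ok {
			return answers // target not in zone: only the CNAMEs gathered so far
		}
		if rec.addr != "" {
			return append(answers, name+" A "+rec.addr)
		}
		answers = append(answers, name+" CNAME "+rec.cnameTarget)
		name = rec.cnameTarget
	}
	return answers // chain too deep: no terminal A record is appended
}

func main() {
	fmt.Println(len(resolveChain(demoZone, "alias.example.com."))) // CNAME + A
	fmt.Println(len(resolveChain(demoZone, "loop1.test.")))        // capped at 8 CNAMEs
}
```

The depth cap doubling as loop protection is why the circular-CNAME test only asserts `LessOrEqual(len(resp.Answer), 8)`: a loop simply produces eight CNAME answers and stops.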


@@ -0,0 +1,197 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
    switch qtype {
    case dns.TypeA:
        return "ip4"
    case dns.TypeAAAA:
        return "ip6"
    default:
        return ""
    }
}

type resolver interface {
    LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}

// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
    RequestID() string
    SetMeta(key, value string)
}

// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
    if cw, ok := w.(chainedWriter); ok {
        if id := cw.RequestID(); id != "" {
            return id
        }
    }
    return GenerateRequestID()
}

// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
    if cw, ok := w.(chainedWriter); ok {
        cw.SetMeta(key, value)
    }
}

// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
    IPs   []netip.Addr
    Rcode int
    Err   error // Original error for caller's logging needs
}

// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
    ips, err := r.LookupNetIP(ctx, network, host)
    if err != nil {
        return LookupResult{
            Rcode: getRcodeForError(ctx, r, host, qtype, err),
            Err:   err,
        }
    }

    // Unmap IPv4-mapped IPv6 addresses that some resolvers may return
    for i, ip := range ips {
        ips[i] = ip.Unmap()
    }

    return LookupResult{
        IPs:   ips,
        Rcode: dns.RcodeSuccess,
    }
}

func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
    var dnsErr *net.DNSError
    if !errors.As(err, &dnsErr) {
        return dns.RcodeServerFailure
    }
    if dnsErr.IsNotFound {
        return getRcodeForNotFound(ctx, r, host, qtype)
    }
    return dns.RcodeServerFailure
}

// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
    // Try querying for a different record type to see if the domain exists.
    // If the original query was for AAAA, try A. If it was for A, try AAAA.
    // This helps distinguish between NXDOMAIN and NODATA.
    var alternativeNetwork string
    switch originalQtype {
    case dns.TypeAAAA:
        alternativeNetwork = "ip4"
    case dns.TypeA:
        alternativeNetwork = "ip6"
    default:
        return dns.RcodeNameError
    }

    if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
        var dnsErr *net.DNSError
        if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
            // Alternative query also returned not found - domain truly doesn't exist
            return dns.RcodeNameError
        }
        // Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
        return dns.RcodeSuccess
    }

    // Alternative query succeeded - domain exists but has no records of this type
    return dns.RcodeSuccess
}

// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
    if len(answers) == 0 {
        return "[]"
    }

    parts := make([]string, 0, len(answers))
    for _, rr := range answers {
        switch r := rr.(type) {
        case *dns.A:
            parts = append(parts, r.A.String())
        case *dns.AAAA:
            parts = append(parts, r.AAAA.String())
        case *dns.CNAME:
            parts = append(parts, "CNAME:"+r.Target)
        case *dns.PTR:
            parts = append(parts, "PTR:"+r.Ptr)
        default:
            parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
        }
    }

    return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
         }
     }
-    localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
+    localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
     if err != nil {
         return fmt.Errorf("local handler updater: %w", err)
     }
@@ -498,8 +498,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
     s.updateMux(muxUpdates)
-    // register local records
-    s.localResolver.Update(localRecords)
+    s.localResolver.Update(localZones)
     s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -632,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
     handler, err := newUpstreamResolver(
         s.ctx,
-        s.wgInterface.Name(),
-        s.wgInterface.Address().IP,
-        s.wgInterface.Address().Network,
+        s.wgInterface,
         s.statusRecorder,
         s.hostsDNSHolder,
         nbdns.RootZone,
@@ -659,9 +656,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
     s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
 }
-func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
+func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
     var muxUpdates []handlerWrapper
-    var localRecords []nbdns.SimpleRecord
+    var zones []nbdns.CustomZone
     for _, customZone := range customZones {
         if len(customZone.Records) == 0 {
@@ -675,17 +672,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
             priority: PriorityLocal,
         })
-        // zone records contain the fqdn, so we can just flatten them
+        var localRecords []nbdns.SimpleRecord
         for _, record := range customZone.Records {
             if record.Class != nbdns.DefaultClass {
                 log.Warnf("received an invalid class type: %s", record.Class)
                 continue
             }
+            // zone records contain the fqdn, so we can just flatten them
             localRecords = append(localRecords, record)
         }
+        customZone.Records = localRecords
+        zones = append(zones, customZone)
     }
-    return muxUpdates, localRecords, nil
+    return muxUpdates, zones, nil
 }
 func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
@@ -741,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
     log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
     handler, err := newUpstreamResolver(
         s.ctx,
-        s.wgInterface.Name(),
-        s.wgInterface.Address().IP,
-        s.wgInterface.Address().Network,
+        s.wgInterface,
         s.statusRecorder,
         s.hostsDNSHolder,
         domainGroup.domain,
@@ -924,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {
     handler, err := newUpstreamResolver(
         s.ctx,
-        s.wgInterface.Name(),
-        s.wgInterface.Address().IP,
-        s.wgInterface.Address().Network,
+        s.wgInterface,
         s.statusRecorder,
         s.hostsDNSHolder,
         nbdns.RootZone,

View File

@@ -15,6 +15,7 @@ import (
     log "github.com/sirupsen/logrus"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/mock"
+    "golang.zx2c4.com/wireguard/tun/netstack"
     "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
     "github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
     return configurer.WGStats{}, nil
 }
+func (w *mocWGIface) GetNet() *netstack.Net {
+    return nil
+}
+
 var zoneRecords = []nbdns.SimpleRecord{
     {
         Name: "peera.netbird.cloud",
@@ -128,7 +133,7 @@ func TestUpdateDNSServer(t *testing.T) {
     testCases := []struct {
         name             string
         initUpstreamMap  registeredHandlerMap
-        initLocalRecords []nbdns.SimpleRecord
+        initLocalZones   []nbdns.CustomZone
         initSerial       uint64
         inputSerial      uint64
         inputUpdate      nbdns.Config
@@ -180,8 +185,8 @@ func TestUpdateDNSServer(t *testing.T) {
             expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
         },
         {
-            name:             "New Config Should Succeed",
-            initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
+            name:           "New Config Should Succeed",
+            initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
             initUpstreamMap: registeredHandlerMap{
                 generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
                     domain: "netbird.cloud",
@@ -221,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
             expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
         },
         {
-            name:             "Smaller Config Serial Should Be Skipped",
-            initLocalRecords: []nbdns.SimpleRecord{},
+            name:           "Smaller Config Serial Should Be Skipped",
+            initLocalZones: []nbdns.CustomZone{},
             initUpstreamMap: make(registeredHandlerMap),
             initSerial:      2,
             inputSerial:     1,
             shouldFail:      true,
         },
         {
-            name:             "Empty NS Group Domain Or Not Primary Element Should Fail",
-            initLocalRecords: []nbdns.SimpleRecord{},
+            name:           "Empty NS Group Domain Or Not Primary Element Should Fail",
+            initLocalZones: []nbdns.CustomZone{},
             initUpstreamMap: make(registeredHandlerMap),
             initSerial:      0,
             inputSerial:     1,
             inputUpdate: nbdns.Config{
                 ServiceEnable: true,
                 CustomZones: []nbdns.CustomZone{
@@ -251,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
             shouldFail: true,
         },
         {
-            name:             "Invalid NS Group Nameservers list Should Fail",
-            initLocalRecords: []nbdns.SimpleRecord{},
+            name:           "Invalid NS Group Nameservers list Should Fail",
+            initLocalZones: []nbdns.CustomZone{},
             initUpstreamMap: make(registeredHandlerMap),
             initSerial:      0,
             inputSerial:     1,
             inputUpdate: nbdns.Config{
                 ServiceEnable: true,
                 CustomZones: []nbdns.CustomZone{
@@ -273,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
             shouldFail: true,
         },
         {
-            name:             "Invalid Custom Zone Records list Should Skip",
-            initLocalRecords: []nbdns.SimpleRecord{},
+            name:           "Invalid Custom Zone Records list Should Skip",
+            initLocalZones: []nbdns.CustomZone{},
             initUpstreamMap: make(registeredHandlerMap),
             initSerial:      0,
             inputSerial:     1,
             inputUpdate: nbdns.Config{
                 ServiceEnable: true,
                 CustomZones: []nbdns.CustomZone{
@@ -299,8 +304,8 @@ func TestUpdateDNSServer(t *testing.T) {
             }},
         },
         {
-            name:             "Empty Config Should Succeed and Clean Maps",
-            initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
+            name:           "Empty Config Should Succeed and Clean Maps",
+            initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
             initUpstreamMap: registeredHandlerMap{
                 generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
                     domain: zoneRecords[0].Name,
@@ -315,8 +320,8 @@ func TestUpdateDNSServer(t *testing.T) {
             expectedLocalQs: []dns.Question{},
         },
         {
-            name:             "Disabled Service Should clean map",
-            initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
+            name:           "Disabled Service Should clean map",
+            initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
             initUpstreamMap: registeredHandlerMap{
                 generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
                     domain: zoneRecords[0].Name,
@@ -385,7 +390,7 @@ func TestUpdateDNSServer(t *testing.T) {
             }()
             dnsServer.dnsMuxMap = testCase.initUpstreamMap
-            dnsServer.localResolver.Update(testCase.initLocalRecords)
+            dnsServer.localResolver.Update(testCase.initLocalZones)
             dnsServer.updateSerial = testCase.initSerial
             err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -510,8 +515,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
             priority: PriorityUpstream,
         },
     }
-    //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
-    dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
+    dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
     dnsServer.updateSerial = 0
     nameServers := []nbdns.NameServer{
@@ -2048,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
 func TestLocalResolverPriorityConstants(t *testing.T) {
     // Test that priority constants are ordered correctly
-    assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
+    assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
     assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
     assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -2,7 +2,6 @@ package dns
 import (
     "context"
-    "crypto/rand"
     "crypto/sha256"
     "encoding/hex"
     "errors"
@@ -19,8 +18,10 @@ import (
     "github.com/hashicorp/go-multierror"
     "github.com/miekg/dns"
     log "github.com/sirupsen/logrus"
+    "golang.zx2c4.com/wireguard/tun/netstack"
     "github.com/netbirdio/netbird/client/iface"
+    "github.com/netbirdio/netbird/client/internal/dns/resutil"
     "github.com/netbirdio/netbird/client/internal/dns/types"
     "github.com/netbirdio/netbird/client/internal/peer"
     "github.com/netbirdio/netbird/client/proto"
@@ -113,10 +114,7 @@ func (u *upstreamResolverBase) Stop() {
 // ServeDNS handles a DNS request
 func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
-    requestID := GenerateRequestID()
-    logger := log.WithField("request_id", requestID)
-    logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
+    logger := log.WithField("request_id", resutil.GetRequestID(w))
     u.prepareRequest(r)
@@ -202,11 +200,18 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
 func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
     u.successCount.Add(1)
+    logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
+    resutil.SetMeta(w, "upstream", upstream.String())
+
+    // Clear Zero bit from external responses to prevent upstream servers from
+    // manipulating our internal fallthrough signaling mechanism
+    rm.MsgHdr.Zero = false
     if err := w.WriteMsg(rm); err != nil {
         logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
+        return true
     }
     return true
 }
@@ -414,16 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
     return rm, t, nil
 }
-func GenerateRequestID() string {
-    bytes := make([]byte, 4)
-    _, err := rand.Read(bytes)
+// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
+// This is needed when netstack is enabled to reach peer IPs through the tunnel.
+func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
+    reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
     if err != nil {
-        log.Errorf("failed to generate request ID: %v", err)
-        return ""
+        return nil, err
     }
-    return hex.EncodeToString(bytes)
+
+    // If response is truncated, retry with TCP
+    if reply != nil && reply.MsgHdr.Truncated {
+        log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
+            r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
+        return netstackExchange(ctx, nsNet, r, upstream, "tcp")
+    }
+    return reply, nil
 }
+
+func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
+    conn, err := nsNet.DialContext(ctx, network, upstream)
+    if err != nil {
+        return nil, fmt.Errorf("with %s: %w", network, err)
+    }
+    defer func() {
+        if err := conn.Close(); err != nil {
+            log.Debugf("failed to close DNS connection: %v", err)
+        }
+    }()
+
+    if deadline, ok := ctx.Deadline(); ok {
+        if err := conn.SetDeadline(deadline); err != nil {
+            return nil, fmt.Errorf("set deadline: %w", err)
+        }
+    }
+
+    dnsConn := &dns.Conn{Conn: conn}
+    if err := dnsConn.WriteMsg(r); err != nil {
+        return nil, fmt.Errorf("write %s message: %w", network, err)
+    }
+
+    reply, err := dnsConn.ReadMsg()
+    if err != nil {
+        return nil, fmt.Errorf("read %s message: %w", network, err)
+    }
+    return reply, nil
+}
 // FormatPeerStatus formats peer connection status information for debugging DNS timeouts
 func FormatPeerStatus(peerState *peer.State) string {
     isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -23,9 +23,7 @@ type upstreamResolver struct {
 // first time, and we need to wait for a while to start to use again the proper DNS resolver.
 func newUpstreamResolver(
     ctx context.Context,
-    _ string,
-    _ netip.Addr,
-    _ netip.Prefix,
+    _ WGIface,
     statusRecorder *peer.Status,
     hostsDNSHolder *hostsDNSHolder,
     domain string,

View File

@@ -5,22 +5,23 @@ package dns
 import (
     "context"
     "net/netip"
+    "runtime"
     "time"
     "github.com/miekg/dns"
+    "golang.zx2c4.com/wireguard/tun/netstack"
     "github.com/netbirdio/netbird/client/internal/peer"
 )
 type upstreamResolver struct {
     *upstreamResolverBase
+    nsNet *netstack.Net
 }
 func newUpstreamResolver(
     ctx context.Context,
-    _ string,
-    _ netip.Addr,
-    _ netip.Prefix,
+    wgIface WGIface,
     statusRecorder *peer.Status,
     _ *hostsDNSHolder,
     domain string,
@@ -28,12 +29,23 @@ func newUpstreamResolver(
     upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
     nonIOS := &upstreamResolver{
         upstreamResolverBase: upstreamResolverBase,
+        nsNet:                wgIface.GetNet(),
     }
     upstreamResolverBase.upstreamClient = nonIOS
     return nonIOS, nil
 }
 func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
+    // TODO: Check if upstream DNS server is routed through a peer before using netstack.
+    // Similar to iOS logic, we should determine if the DNS server is reachable directly
+    // or needs to go through the tunnel, and only use netstack when necessary.
+    // For now, only use netstack on JS platform where direct access is not possible.
+    if u.nsNet != nil && runtime.GOOS == "js" {
+        start := time.Now()
+        reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
+        return reply, time.Since(start), err
+    }
+
     client := &dns.Client{
         Timeout: ClientTimeout,
     }

View File

@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
 func newUpstreamResolver(
     ctx context.Context,
-    interfaceName string,
-    ip netip.Addr,
-    net netip.Prefix,
+    wgIface WGIface,
     statusRecorder *peer.Status,
     _ *hostsDNSHolder,
     domain string,
@@ -37,9 +35,9 @@ func newUpstreamResolver(
     ios := &upstreamResolverIOS{
         upstreamResolverBase: upstreamResolverBase,
-        lIP:           ip,
-        lNet:          net,
-        interfaceName: interfaceName,
+        lIP:           wgIface.Address().IP,
+        lNet:          wgIface.Address().Network,
+        interfaceName: wgIface.Name(),
     }
     ios.upstreamClient = ios

View File

@@ -2,13 +2,17 @@ package dns
 import (
     "context"
+    "net"
     "net/netip"
     "strings"
     "testing"
     "time"
     "github.com/miekg/dns"
+    "golang.zx2c4.com/wireguard/tun/netstack"
+
+    "github.com/netbirdio/netbird/client/iface/device"
+    "github.com/netbirdio/netbird/client/iface/wgaddr"
     "github.com/netbirdio/netbird/client/internal/dns/test"
 )
@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
     for _, testCase := range testCases {
         t.Run(testCase.name, func(t *testing.T) {
             ctx, cancel := context.WithCancel(context.TODO())
-            resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
+            resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
             // Convert test servers to netip.AddrPort
             var servers []netip.AddrPort
             for _, server := range testCase.InputServers {
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
     }
 }
+type mockNetstackProvider struct{}
+
+func (m *mockNetstackProvider) Name() string                      { return "mock" }
+func (m *mockNetstackProvider) Address() wgaddr.Address           { return wgaddr.Address{} }
+func (m *mockNetstackProvider) ToInterface() *net.Interface       { return nil }
+func (m *mockNetstackProvider) IsUserspaceBind() bool             { return false }
+func (m *mockNetstackProvider) GetFilter() device.PacketFilter    { return nil }
+func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
+func (m *mockNetstackProvider) GetNet() *netstack.Net             { return nil }
+func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
+    return "", nil
+}
+
 type mockUpstreamResolver struct {
     r   *dns.Msg
     rtt time.Duration

View File

@@ -5,6 +5,8 @@ package dns
 import (
     "net"
+    "golang.zx2c4.com/wireguard/tun/netstack"
+
     "github.com/netbirdio/netbird/client/iface/device"
     "github.com/netbirdio/netbird/client/iface/wgaddr"
 )
@@ -17,4 +19,5 @@ type WGIface interface {
     IsUserspaceBind() bool
     GetFilter() device.PacketFilter
     GetDevice() *device.FilteredDevice
+    GetNet() *netstack.Net
 }

View File

@@ -1,6 +1,8 @@
 package dns
 import (
+    "golang.zx2c4.com/wireguard/tun/netstack"
+
     "github.com/netbirdio/netbird/client/iface/device"
     "github.com/netbirdio/netbird/client/iface/wgaddr"
 )
@@ -12,5 +14,6 @@ type WGIface interface {
     IsUserspaceBind() bool
     GetFilter() device.PacketFilter
     GetDevice() *device.FilteredDevice
+    GetNet() *netstack.Net
     GetInterfaceGUIDString() (string, error)
 }

View File

@@ -18,6 +18,7 @@ import (
     nberrors "github.com/netbirdio/netbird/client/errors"
     firewall "github.com/netbirdio/netbird/client/firewall/manager"
+    "github.com/netbirdio/netbird/client/internal/dns/resutil"
     "github.com/netbirdio/netbird/client/internal/peer"
     "github.com/netbirdio/netbird/route"
 )
@@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
     return nberrors.FormatErrorOrNil(result)
 }
-func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
+func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
     if len(query.Question) == 0 {
         return nil
     }
     question := query.Question[0]
-    log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
-        question.Name, question.Qtype, question.Qclass)
+    logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
+        question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
     domain := strings.ToLower(question.Name)
     resp := query.SetReply(query)
-    var network string
-    switch question.Qtype {
-    case dns.TypeA:
-        network = "ip4"
-    case dns.TypeAAAA:
-        network = "ip6"
-    default:
-        // TODO: Handle other types
+    network := resutil.NetworkForQtype(question.Qtype)
+    if network == "" {
         resp.Rcode = dns.RcodeNotImplemented
         if err := w.WriteMsg(resp); err != nil {
-            log.Errorf("failed to write DNS response: %v", err)
+            logger.Errorf("failed to write DNS response: %v", err)
         }
         return nil
     }
@@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
     if mostSpecificResId == "" {
         resp.Rcode = dns.RcodeRefused
         if err := w.WriteMsg(resp); err != nil {
-            log.Errorf("failed to write DNS response: %v", err)
+            logger.Errorf("failed to write DNS response: %v", err)
         }
         return nil
     }
     ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
     defer cancel()
-    ips, err := f.resolver.LookupNetIP(ctx, network, domain)
-    if err != nil {
-        f.handleDNSError(ctx, w, question, resp, domain, err)
+
+    result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
+    if result.Err != nil {
+        f.handleDNSError(ctx, logger, w, question, resp, domain, result)
         return nil
     }
-    // Unmap IPv4-mapped IPv6 addresses that some resolvers may return
-    for i, ip := range ips {
-        ips[i] = ip.Unmap()
-    }
-
-    f.updateInternalState(ips, mostSpecificResId, matchingEntries)
-    f.addIPsToResponse(resp, domain, ips)
-    f.cache.set(domain, question.Qtype, ips)
+    f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
+    resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
+    f.cache.set(domain, question.Qtype, result.IPs)
     return resp
 }
 func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
-    resp := f.handleDNSQuery(w, query)
+    startTime := time.Now()
+    logger := log.WithFields(log.Fields{
+        "request_id": resutil.GenerateRequestID(),
+        "dns_id":     fmt.Sprintf("%04x", query.Id),
+    })
+    resp := f.handleDNSQuery(logger, w, query)
     if resp == nil {
         return
     }
@@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
     }
     if err := w.WriteMsg(resp); err != nil {
-        log.Errorf("failed to write DNS response: %v", err)
+        logger.Errorf("failed to write DNS response: %v", err)
+        return
     }
+
+    logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
+        query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
 }
 func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
-    resp := f.handleDNSQuery(w, query)
+    startTime := time.Now()
+    logger := log.WithFields(log.Fields{
+        "request_id": resutil.GenerateRequestID(),
+        "dns_id":     fmt.Sprintf("%04x", query.Id),
+    })
+    resp := f.handleDNSQuery(logger, w, query)
     if resp == nil {
         return
     }
     if err := w.WriteMsg(resp); err != nil {
-        log.Errorf("failed to write DNS response: %v", err)
+        logger.Errorf("failed to write DNS response: %v", err)
+        return
     }
+
+    logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
+        query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
 }
 func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
     }
 }
-// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
-// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
-//
-// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
-// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
-// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
-// only handles A/AAAA queries and returns NOTIMP for other types.
-func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
-    // Try querying for a different record type to see if the domain exists
-    // If the original query was for AAAA, try A. If it was for A, try AAAA.
-    // This helps distinguish between NXDOMAIN and NODATA.
-    var alternativeNetwork string
-    switch originalQtype {
-    case dns.TypeAAAA:
-        alternativeNetwork = "ip4"
-    case dns.TypeA:
-        alternativeNetwork = "ip6"
-    default:
-        resp.Rcode = dns.RcodeNameError
-        return
-    }
-
-    if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
-        var dnsErr *net.DNSError
-        if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
-            // Alternative query also returned not found - domain truly doesn't exist
-            resp.Rcode = dns.RcodeNameError
-            return
-        }
-        // Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
-        resp.Rcode = dns.RcodeSuccess
-        return
-    }
-
-    // Alternative query succeeded - domain exists but has no records of this type
-    resp.Rcode = dns.RcodeSuccess
-}
-
 // handleDNSError processes DNS lookup errors and sends an appropriate error response.
 func (f *DNSForwarder) handleDNSError(
     ctx context.Context,
+    logger *log.Entry,
     w dns.ResponseWriter,
     question dns.Question,
     resp *dns.Msg,
     domain string,
-    err error,
+    result resutil.LookupResult,
 ) {
-    // Default to SERVFAIL; override below when appropriate.
-    resp.Rcode = dns.RcodeServerFailure
     qType := question.Qtype
     qTypeName := dns.TypeToString[qType]
+    resp.Rcode = result.Rcode
-    // Prefer typed DNS errors; fall back to generic logging otherwise.
-    var dnsErr *net.DNSError
-    if !errors.As(err, &dnsErr) {
-        log.Warnf(errResolveFailed, domain, err)
-        if writeErr := w.WriteMsg(resp); writeErr != nil {
-            log.Errorf("failed to write failure DNS response: %v", writeErr)
-        }
-        return
-    }
-
-    // NotFound: set NXDOMAIN / appropriate code via helper.
-    if dnsErr.IsNotFound {
+    // NotFound: cache negative result and respond
+    if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.cache.set(domain, question.Qtype, nil) f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return return
} }
// Upstream failed but we might have a cached answer—serve it if present. // Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok { if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 { if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips) resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil { if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr) logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
} }
return
}
// Cached negative result - re-verify NXDOMAIN vs NODATA
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
} }
return
} }
// No cache. Log with or without the server field for more context. // No cache or verification failed. Log with or without the server field for more context.
if dnsErr.Server != "" { var dnsErr *net.DNSError
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else { } else {
log.Warnf(errResolveFailed, domain, err) logger.Warnf(errResolveFailed, domain, result.Err)
} }
// Write final failure response. // Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil { if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr) logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
} }
} }

View File

@@ -10,6 +10,7 @@ import (
     "time"

     "github.com/miekg/dns"
+    log "github.com/sirupsen/logrus"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/mock"
     "github.com/stretchr/testify/require"
@@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
     query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)

     mockWriter := &test.MockResponseWriter{}
-    resp := forwarder.handleDNSQuery(mockWriter, query)
+    resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)

     if tt.shouldResolve {
         require.NotNil(t, resp, "Expected response for authorized domain")
@@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
     dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)

     mockWriter := &test.MockResponseWriter{}
-    resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
+    resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)

     // Verify response
     if tt.shouldResolve {
@@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
     query.SetQuestion("example.com.", dns.TypeA)

     mockWriter := &test.MockResponseWriter{}
-    resp := forwarder.handleDNSQuery(mockWriter, query)
+    resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)

     // Verify response contains all IPs
     require.NotNil(t, resp)
@@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
         },
     }

-    _ = forwarder.handleDNSQuery(mockWriter, query)
+    _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)

     // Check the response written to the writer
     require.NotNil(t, writtenResp, "Expected response to be written")
@@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
     q1 := &dns.Msg{}
     q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
     w1 := &test.MockResponseWriter{}
-    resp1 := forwarder.handleDNSQuery(w1, q1)
+    resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
     require.NotNil(t, resp1)
     require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
     require.Len(t, resp1.Answer, 1)
@@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
     q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
     var writtenResp *dns.Msg
     w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
-    _ = forwarder.handleDNSQuery(w2, q2)
+    _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)

     require.NotNil(t, writtenResp, "expected response to be written")
     require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
     q1 := &dns.Msg{}
     q1.SetQuestion(mixedQuery+".", dns.TypeA)
     w1 := &test.MockResponseWriter{}
-    resp1 := forwarder.handleDNSQuery(w1, q1)
+    resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
     require.NotNil(t, resp1)
     require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
     require.Len(t, resp1.Answer, 1)
@@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
     q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
     var writtenResp *dns.Msg
     w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
-    _ = forwarder.handleDNSQuery(w2, q2)
+    _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)

     require.NotNil(t, writtenResp)
     require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
     query.SetQuestion("smtp.mail.example.com.", dns.TypeA)

     mockWriter := &test.MockResponseWriter{}
-    resp := forwarder.handleDNSQuery(mockWriter, query)
+    resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)

     require.NotNil(t, resp)
     assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
         },
     }

-    resp := forwarder.handleDNSQuery(mockWriter, query)
+    resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)

     // If a response was returned, it means it should be written (happens in wrapper functions)
     if resp != nil && writtenResp == nil {
@@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
             return nil
         },
     }

-    resp := forwarder.handleDNSQuery(mockWriter, query)
+    resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)

     assert.Nil(t, resp, "Should return nil for empty query")
     assert.False(t, writeCalled, "Should not write response for empty query")
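These tests capture responses through a mock writer whose `WriteMsgFunc` records what the handler writes. A minimal stdlib-only sketch of that capture pattern — `msg` and `mockResponseWriter` are simplified hypothetical types, not the miekg/dns or NetBird test types:

```go
package main

import "fmt"

// msg is a simplified stand-in for a DNS message.
type msg struct {
	Rcode int
}

// mockResponseWriter records whatever the code under test writes, via a
// pluggable callback, so assertions can inspect it afterwards.
type mockResponseWriter struct {
	WriteMsgFunc func(m *msg) error
}

func (w *mockResponseWriter) WriteMsg(m *msg) error {
	if w.WriteMsgFunc != nil {
		return w.WriteMsgFunc(m)
	}
	return nil
}

func main() {
	var written *msg
	w := &mockResponseWriter{WriteMsgFunc: func(m *msg) error { written = m; return nil }}

	// A handler under test would call w.WriteMsg; simulate one write here.
	_ = w.WriteMsg(&msg{Rcode: 0})
	fmt.Println(written.Rcode) // 0
}
```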

View File

@@ -1251,11 +1251,16 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
         ForwarderPort: forwarderPort,
     }

-    for _, zone := range protoDNSConfig.GetCustomZones() {
+    protoZones := protoDNSConfig.GetCustomZones()
+    // Treat single zone as authoritative for backward compatibility with old servers
+    // that only send the peer FQDN zone without setting field 4.
+    singleZoneCompat := len(protoZones) == 1
+    for _, zone := range protoZones {
         dnsZone := nbdns.CustomZone{
             Domain:               zone.GetDomain(),
             SearchDomainDisabled: zone.GetSearchDomainDisabled(),
-            SkipPTRProcess:       zone.GetSkipPTRProcess(),
+            NonAuthoritative:     zone.GetNonAuthoritative() && !singleZoneCompat,
         }
         for _, record := range zone.Records {
             dnsRecord := nbdns.SimpleRecord{
@@ -1743,22 +1748,26 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
     }
     e.syncMsgMux.Unlock()

-    var results []relay.ProbeResult
-    if waitForResult {
-        results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
-    } else {
-        results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
-    }
-    e.statusRecorder.UpdateRelayStates(results)
-
+    // Skip STUN/TURN probing for JS/WASM as it's not available
     relayHealthy := true
-    for _, res := range results {
-        if res.Err != nil {
-            relayHealthy = false
-            break
+    if runtime.GOOS != "js" {
+        var results []relay.ProbeResult
+        if waitForResult {
+            results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
+        } else {
+            results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
         }
+        e.statusRecorder.UpdateRelayStates(results)
+
+        for _, res := range results {
+            if res.Err != nil {
+                relayHealthy = false
+                break
+            }
+        }
+        log.Debugf("relay health check: healthy=%t", relayHealthy)
     }
-    log.Debugf("relay health check: healthy=%t", relayHealthy)

     allHealthy := signalHealthy && managementHealthy && relayHealthy
     log.Debugf("all health checks completed: healthy=%t", allHealthy)

View File

@@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
     }

     if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
+        audiences := protoJWT.GetAudiences()
+        if len(audiences) == 0 && protoJWT.GetAudience() != "" {
+            audiences = []string{protoJWT.GetAudience()}
+        }
+        log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
         jwtConfig := &sshserver.JWTConfig{
             Issuer:       protoJWT.GetIssuer(),
-            Audience:     protoJWT.GetAudience(),
+            Audiences:    audiences,
             KeysLocation: protoJWT.GetKeysLocation(),
             MaxTokenAge:  protoJWT.GetMaxTokenAge(),
         }
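The hunk above migrates from a single JWT audience to a repeated field while staying compatible with older management servers that only set the singular field. That normalization can be sketched as a pure function (`normalizeAudiences` is a hypothetical name for illustration):

```go
package main

import "fmt"

// normalizeAudiences prefers the repeated Audiences field and falls back to
// the legacy singular Audience when the new field is absent, mirroring the
// backward-compat handling in updateSSH.
func normalizeAudiences(audiences []string, legacyAudience string) []string {
	if len(audiences) == 0 && legacyAudience != "" {
		return []string{legacyAudience}
	}
	return audiences
}

func main() {
	fmt.Println(normalizeAudiences(nil, "netbird"))                // [netbird]
	fmt.Println(normalizeAudiences([]string{"a", "b"}, "ignored")) // [a b]
	fmt.Println(normalizeAudiences(nil, ""))                       // []
}
```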

View File

@@ -1,5 +1,4 @@
//go:build !windows //go:build !windows
// +build !windows
package internal package internal

View File

@@ -669,10 +669,17 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
         }
     }()

-    if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
+    // For JS platform: only relay connection is supported
+    if runtime.GOOS == "js" {
+        return conn.statusRelay.Get() == worker.StatusConnected
+    }
+
+    // For non-JS platforms: check ICE connection status
+    if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
         return false
     }

+    // If relay is supported with peer, it must also be connected
     if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
         if conn.statusRelay.Get() == worker.StatusDisconnected {
             return false
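The connectivity decision above can be captured as a pure function over the two worker statuses, which makes the js-only relay path explicit. A sketch under simplified assumptions (`connectedOnAllWays` and the `status` type are hypothetical stand-ins for the `Conn`/`worker` state):

```go
package main

import "fmt"

type status int

const (
	statusDisconnected status = iota
	statusConnected
)

// connectedOnAllWays mirrors isConnectedOnAllWay: on js only the relay path
// exists; elsewhere ICE must not be disconnected (unless a reconnect is in
// progress) and, when relay is supported with the peer, relay must be
// connected too.
func connectedOnAllWays(goos string, ice, relay status, iceInProgress, relaySupported bool) bool {
	if goos == "js" {
		return relay == statusConnected
	}
	if ice == statusDisconnected && !iceInProgress {
		return false
	}
	if relaySupported && relay == statusDisconnected {
		return false
	}
	return true
}

func main() {
	fmt.Println(connectedOnAllWays("js", statusDisconnected, statusConnected, false, true))    // true
	fmt.Println(connectedOnAllWays("linux", statusConnected, statusDisconnected, false, true)) // false
	fmt.Println(connectedOnAllWays("linux", statusConnected, statusConnected, false, true))    // true
}
```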

View File

@@ -14,6 +14,7 @@ import (
     "golang.org/x/exp/maps"
     "google.golang.org/grpc/codes"
     gstatus "google.golang.org/grpc/status"
+    "google.golang.org/protobuf/types/known/durationpb"
     "google.golang.org/protobuf/types/known/timestamppb"

     firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -158,6 +159,7 @@ type FullStatus struct {
     NSGroupStates         []NSGroupState
     NumOfForwardingRules  int
     LazyConnectionEnabled bool
+    Events                []*proto.SystemEvent
 }

 type StatusChangeSubscription struct {
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
     }
     fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
+    fullStatus.Events = d.GetEventHistory()

     return fullStatus
 }
@@ -1181,3 +1184,97 @@ type EventSubscription struct {
 func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
     return s.events
 }
+
+// ToProto converts FullStatus to proto.FullStatus.
+func (fs FullStatus) ToProto() *proto.FullStatus {
+    pbFullStatus := proto.FullStatus{
+        ManagementState: &proto.ManagementState{},
+        SignalState:     &proto.SignalState{},
+        LocalPeerState:  &proto.LocalPeerState{},
+        Peers:           []*proto.PeerState{},
+    }
+
+    pbFullStatus.ManagementState.URL = fs.ManagementState.URL
+    pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
+    if err := fs.ManagementState.Error; err != nil {
+        pbFullStatus.ManagementState.Error = err.Error()
+    }
+
+    pbFullStatus.SignalState.URL = fs.SignalState.URL
+    pbFullStatus.SignalState.Connected = fs.SignalState.Connected
+    if err := fs.SignalState.Error; err != nil {
+        pbFullStatus.SignalState.Error = err.Error()
+    }
+
+    pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
+    pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
+    pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
+    pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
+    pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
+    pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
+    pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
+    pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
+    pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
+
+    for _, peerState := range fs.Peers {
+        networks := maps.Keys(peerState.GetRoutes())
+        pbPeerState := &proto.PeerState{
+            IP:                         peerState.IP,
+            PubKey:                     peerState.PubKey,
+            ConnStatus:                 peerState.ConnStatus.String(),
+            ConnStatusUpdate:           timestamppb.New(peerState.ConnStatusUpdate),
+            Relayed:                    peerState.Relayed,
+            LocalIceCandidateType:      peerState.LocalIceCandidateType,
+            RemoteIceCandidateType:     peerState.RemoteIceCandidateType,
+            LocalIceCandidateEndpoint:  peerState.LocalIceCandidateEndpoint,
+            RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
+            RelayAddress:               peerState.RelayServerAddress,
+            Fqdn:                       peerState.FQDN,
+            LastWireguardHandshake:     timestamppb.New(peerState.LastWireguardHandshake),
+            BytesRx:                    peerState.BytesRx,
+            BytesTx:                    peerState.BytesTx,
+            RosenpassEnabled:           peerState.RosenpassEnabled,
+            Networks:                   networks,
+            Latency:                    durationpb.New(peerState.Latency),
+            SshHostKey:                 peerState.SSHHostKey,
+        }
+        pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
+    }
+
+    for _, relayState := range fs.Relays {
+        pbRelayState := &proto.RelayState{
+            URI:       relayState.URI,
+            Available: relayState.Err == nil,
+        }
+        if err := relayState.Err; err != nil {
+            pbRelayState.Error = err.Error()
+        }
+        pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
+    }
+
+    for _, dnsState := range fs.NSGroupStates {
+        var err string
+        if dnsState.Error != nil {
+            err = dnsState.Error.Error()
+        }
+        var servers []string
+        for _, server := range dnsState.Servers {
+            servers = append(servers, server.String())
+        }
+        pbDnsState := &proto.NSGroupState{
+            Servers: servers,
+            Domains: dnsState.Domains,
+            Enabled: dnsState.Enabled,
+            Error:   err,
+        }
+        pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
+    }
+
+    pbFullStatus.Events = fs.Events
+    return &pbFullStatus
+}

View File

@@ -17,12 +17,13 @@ import (
     nberrors "github.com/netbirdio/netbird/client/errors"
     firewall "github.com/netbirdio/netbird/client/firewall/manager"
-    "github.com/netbirdio/netbird/client/iface/wgaddr"
     nbdns "github.com/netbirdio/netbird/client/internal/dns"
+    "github.com/netbirdio/netbird/client/internal/dns/resutil"
     "github.com/netbirdio/netbird/client/internal/peer"
     "github.com/netbirdio/netbird/client/internal/peerstore"
     "github.com/netbirdio/netbird/client/internal/routemanager/common"
     "github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
+    iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
     "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
     "github.com/netbirdio/netbird/route"
     "github.com/netbirdio/netbird/shared/management/domain"
@@ -37,11 +38,6 @@ type internalDNATer interface {
     AddInternalDNATMapping(netip.Addr, netip.Addr) error
 }

-type wgInterface interface {
-    Name() string
-    Address() wgaddr.Address
-}
-
 type DnsInterceptor struct {
     mu    sync.RWMutex
     route *route.Route
@@ -51,7 +47,7 @@ type DnsInterceptor struct {
     dnsServer          nbdns.Server
     currentPeerKey     string
     interceptedDomains domainMap
-    wgInterface        wgInterface
+    wgInterface        iface.WGIface
     peerStore          *peerstore.Store
     firewall           firewall.Manager
     fakeIPManager      *fakeip.Manager
@@ -219,14 +215,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
 // ServeDNS implements the dns.Handler interface
 func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
-    requestID := nbdns.GenerateRequestID()
-    logger := log.WithField("request_id", requestID)
+    logger := log.WithFields(log.Fields{
+        "request_id": resutil.GetRequestID(w),
+        "dns_id":     fmt.Sprintf("%04x", r.Id),
+    })
     if len(r.Question) == 0 {
         return
     }
-    logger.Tracef("received DNS request for domain=%s type=%v class=%v",
-        r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)

     // pass if non A/AAAA query
     if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
@@ -249,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
         return
     }

-    client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
-    if err != nil {
-        d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
-        return
-    }
-
     if r.Extra == nil {
         r.MsgHdr.AuthenticatedData = true
     }
@@ -263,32 +253,15 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
     ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
     defer cancel()

-    startTime := time.Now()
-    reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
-    if err != nil {
-        if errors.Is(err, context.DeadlineExceeded) {
-            elapsed := time.Since(startTime)
-            peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
-            logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
-                elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
-        } else {
-            logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
-        }
-        if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
-            logger.Errorf("failed writing DNS response: %v", err)
-        }
+    reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
+    if reply == nil {
         return
     }

-    var answer []dns.RR
-    if reply != nil {
-        answer = reply.Answer
-    }
-    logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
+    resutil.SetMeta(w, "peer", peerKey)

     reply.Id = r.Id
-    if err := d.writeMsg(w, reply); err != nil {
+    if err := d.writeMsg(w, reply, logger); err != nil {
         logger.Errorf("failed writing DNS response: %v", err)
     }
 }
@@ -324,11 +297,15 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
     return peerAllowedIP, nil
 }

-func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
+func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
     if r == nil {
         return fmt.Errorf("received nil DNS message")
     }

+    // Clear Zero bit from peer responses to prevent external sources from
+    // manipulating our internal fallthrough signaling mechanism
+    r.MsgHdr.Zero = false
+
     if len(r.Answer) > 0 && len(r.Question) > 0 {
         origPattern := ""
         if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
@@ -350,14 +327,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
         case *dns.A:
             addr, ok := netip.AddrFromSlice(rr.A)
             if !ok {
-                log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
+                logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
                 continue
             }
             ip = addr
         case *dns.AAAA:
             addr, ok := netip.AddrFromSlice(rr.AAAA)
             if !ok {
-                log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
+                logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
                 continue
             }
             ip = addr
@@ -370,11 +347,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
     }

     if len(newPrefixes) > 0 {
-        if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
-            log.Errorf("failed to update domain prefixes: %v", err)
+        if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
+            logger.Errorf("failed to update domain prefixes: %v", err)
         }
-        d.replaceIPsInDNSResponse(r, newPrefixes)
+        d.replaceIPsInDNSResponse(r, newPrefixes, logger)
     }
 }
@@ -386,22 +363,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
 }

 // logPrefixChanges handles the logging for prefix changes
-func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
+func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
     if len(toAdd) > 0 {
-        log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
+        logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
             resolvedDomain.SafeString(),
             originalDomain.SafeString(),
             toAdd)
     }

     if len(toRemove) > 0 && !d.route.KeepRoute {
-        log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
+        logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
             resolvedDomain.SafeString(),
             originalDomain.SafeString(),
             toRemove)
     }
 }

-func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
+func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
     d.mu.Lock()
     defer d.mu.Unlock()
@@ -418,9 +395,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
             realIP := prefix.Addr()
             if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
                 dnatMappings[fakeIP] = realIP
-                log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
+                logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
             } else {
-                log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
+                logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
             }
         }
     }
@@ -432,7 +409,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
         }
     }

-    d.addDNATMappings(dnatMappings)
+    d.addDNATMappings(dnatMappings, logger)

     if !d.route.KeepRoute {
         // Remove old prefixes
@@ -448,7 +425,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
             }
         }
-        d.removeDNATMappings(toRemove)
+        d.removeDNATMappings(toRemove, logger)
     }

     // Update domain prefixes using resolved domain as key - store real IPs
@@ -463,14 +440,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
         // Store real IPs for status (user-facing), not fake IPs
         d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
-        d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
+        d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
     }

     return nberrors.FormatErrorOrNil(merr)
 }

 // removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
-func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
+func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
     if len(realPrefixes) == 0 {
         return
     }
@@ -484,9 +461,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
         realIP := prefix.Addr()
         if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
             if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
-                log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
+                logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
             } else {
-                log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
+                logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
             }
         }
     }
@@ -502,7 +479,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
 }

 // addDNATMappings adds DNAT mappings to the firewall
-func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
+func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
     if len(mappings) == 0 {
         return
     }
@@ -514,9 +491,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
     for fakeIP, realIP := range mappings {
         if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
-            log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
+            logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
         } else {
-            log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
+            logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
         }
     }
 }
@@ -528,12 +505,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
     }

     for _, prefixes := range d.interceptedDomains {
-        d.removeDNATMappings(prefixes)
+        d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
     }
 }

 // replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
-func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
+func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
if _, ok := d.internalDnatFw(); !ok { if _, ok := d.internalDnatFw(); !ok {
return return
} }
@@ -549,7 +526,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice() rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
} }
case *dns.AAAA: case *dns.AAAA:
@@ -560,7 +537,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice() rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
} }
} }
} }
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
return return
} }
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
// Returns the DNS reply on success, or nil on error (error responses are written internally).
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
startTime := time.Now()
nsNet := d.wgInterface.GetNet()
var reply *dns.Msg
var err error
if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil
}
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
}
if err == nil {
return reply
}
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
	if d.statusRecorder == nil {
		return ""

View File

@@ -1,5 +1,4 @@
//go:build !windows

package iface

View File

@@ -4,6 +4,8 @@ import (
	"net"
	"net/netip"

	"golang.zx2c4.com/wireguard/tun/netstack"

	"github.com/netbirdio/netbird/client/iface/device"
	"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
	IsUserspaceBind() bool
	GetFilter() device.PacketFilter
	GetDevice() *device.FilteredDevice
	GetNet() *netstack.Net
}

View File

@@ -210,7 +210,8 @@ func (r *SysOps) refreshLocalSubnetsCache() {
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
	nextHop := Nexthop{netip.Addr{}, intf}

	switch prefix {
	case vars.Defaultv4:
		if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
			return err
		}
@@ -233,7 +234,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
		}

		return nil
	case vars.Defaultv6:
		if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
			return fmt.Errorf("add unreachable route split 1: %w", err)
		}
@@ -255,7 +256,8 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
	nextHop := Nexthop{netip.Addr{}, intf}

	switch prefix {
	case vars.Defaultv4:
		var result *multierror.Error
		if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
			result = multierror.Append(result, err)
@@ -273,7 +275,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
		}

		return nberrors.FormatErrorOrNil(result)
	case vars.Defaultv6:
		var result *multierror.Error
		if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
			result = multierror.Append(result, err)
@@ -283,9 +285,9 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
		}

		return nberrors.FormatErrorOrNil(result)
	default:
		return r.removeFromRouteTable(prefix, nextHop)
	}
}

func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {

View File

@@ -76,7 +76,7 @@ type Client struct {
	loginComplete bool
	connectClient *internal.ConnectClient
	// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
	preloadedConfig *profilemanager.Config
}

// NewClient instantiate a new Client

View File

@@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
	log.SetLevel(level)

	if s.connectClient != nil {
		s.connectClient.SetLogLevel(level)
	}

	log.Infof("Log level set to %s", level.String())

View File

@@ -1,8 +1,6 @@
package server

import (
	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/proto"
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
		}
	}
}

View File

@@ -1,5 +1,4 @@
//go:build windows

package server

View File

@@ -13,15 +13,12 @@ import (
	"time"

	"github.com/cenkalti/backoff/v4"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

	log "github.com/sirupsen/logrus"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	gstatus "google.golang.org/grpc/status"

	"github.com/netbirdio/netbird/client/internal/auth"
	"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -1067,11 +1064,9 @@ func (s *Server) Status(
	if msg.GetFullPeerStatus {
		s.runProbes(msg.ShouldRunProbes)

		fullStatus := s.statusRecorder.GetFullStatus()
		pbFullStatus := fullStatus.ToProto()
		pbFullStatus.Events = s.statusRecorder.GetEventHistory()
		pbFullStatus.SshServerState = s.getSSHServerState()
		statusResponse.FullStatus = pbFullStatus
	}
@@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
	return defaultDuration
}

// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
		HostKeyPEM: hostKey,
		JWT: &server.JWTConfig{
			Issuer:       issuer,
			Audiences:    []string{audience},
			KeysLocation: jwksURL,
		},
	}

View File

@@ -43,7 +43,7 @@ func TestJWTEnforcement(t *testing.T) {
	t.Run("blocks_without_jwt", func(t *testing.T) {
		jwtConfig := &JWTConfig{
			Issuer:       "test-issuer",
			Audiences:    []string{"test-audience"},
			KeysLocation: "test-keys",
		}
		serverConfig := &Config{
@@ -202,7 +202,7 @@ func TestJWTDetection(t *testing.T) {
	jwtConfig := &JWTConfig{
		Issuer:       issuer,
		Audiences:    []string{audience},
		KeysLocation: jwksURL,
	}
	serverConfig := &Config{
@@ -329,7 +329,7 @@ func TestJWTFailClose(t *testing.T) {
		t.Run(tc.name, func(t *testing.T) {
			jwtConfig := &JWTConfig{
				Issuer:       issuer,
				Audiences:    []string{audience},
				KeysLocation: jwksURL,
				MaxTokenAge:  3600,
			}
@@ -567,7 +567,7 @@ func TestJWTAuthentication(t *testing.T) {
	jwtConfig := &JWTConfig{
		Issuer:       issuer,
		Audiences:    []string{audience},
		KeysLocation: jwksURL,
	}
	serverConfig := &Config{
@@ -602,12 +602,13 @@ func TestJWTAuthentication(t *testing.T) {
			require.NoError(t, err)

			var authMethods []cryptossh.AuthMethod
			switch tc.token {
			case "valid":
				token := generateValidJWT(t, privateKey, issuer, audience)
				authMethods = []cryptossh.AuthMethod{
					cryptossh.Password(token),
				}
			case "invalid":
				invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
				authMethods = []cryptossh.AuthMethod{
					cryptossh.Password(invalidToken),
@@ -645,3 +646,108 @@ func TestJWTAuthentication(t *testing.T) {
		})
	}
}
// TestJWTMultipleAudiences tests JWT validation with multiple audiences (dashboard and CLI).
func TestJWTMultipleAudiences(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT multiple audiences tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
dashboardAudience = "dashboard-audience"
cliAudience = "cli-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
audience string
wantAuthOK bool
}{
{
name: "accepts_dashboard_audience",
audience: dashboardAudience,
wantAuthOK: true,
},
{
name: "accepts_cli_audience",
audience: cliAudience,
wantAuthOK: true,
},
{
name: "rejects_unknown_audience",
audience: "unknown-audience",
wantAuthOK: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audiences: []string{dashboardAudience, cliAudience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0},
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := generateValidJWT(t, privateKey, issuer, tc.audience)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(token),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed for audience %s", tc.audience)
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
err = session.Shell()
require.NoError(t, err, "Shell should work with valid audience")
} else {
assert.Error(t, err, "JWT authentication should fail for unknown audience")
}
})
}
}

View File

@@ -176,9 +176,9 @@ type Server struct {
type JWTConfig struct {
	Issuer       string
	KeysLocation string
	MaxTokenAge  int64
	Audiences    []string
}

// Config contains all SSH server configuration options
@@ -427,18 +427,21 @@ func (s *Server) ensureJWTValidator() error {
		return fmt.Errorf("JWT config not set")
	}

	if len(config.Audiences) == 0 {
		return fmt.Errorf("JWT config has no audiences configured")
	}

	log.Debugf("Initializing JWT validator (issuer: %s, audiences: %v)", config.Issuer, config.Audiences)
	validator := jwt.NewValidator(
		config.Issuer,
		config.Audiences,
		config.KeysLocation,
		true,
	)

	// Use custom userIDClaim from authorizer if available
	extractorOptions := []jwt.ClaimsExtractorOption{
		jwt.WithAudience(config.Audiences[0]),
	}
	if authorizer.GetUserIDClaim() != "" {
		extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
@@ -475,8 +478,8 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
	if err != nil {
		if jwtConfig != nil {
			if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
				return nil, fmt.Errorf("validate token (expected issuer=%s, audiences=%v, actual issuer=%v, audience=%v): %w",
					jwtConfig.Issuer, jwtConfig.Audiences, claims["iss"], claims["aud"], err)
			}
		}
		return nil, fmt.Errorf("validate token: %w", err)
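Editor's note: the `Audience` → `Audiences` change above boils down to matching a token's `aud` claim against a list instead of a single value. A minimal sketch of that membership check (the helper name `hasAllowedAudience` is hypothetical, not NetBird API; claim parsing is assumed done elsewhere):

```go
package main

import "fmt"

// hasAllowedAudience reports whether any audience in the token's aud claim
// matches one of the configured audiences — the multi-audience check that a
// []string Audiences field enables.
func hasAllowedAudience(tokenAud []string, allowed []string) bool {
	for _, a := range tokenAud {
		for _, want := range allowed {
			if a == want {
				return true
			}
		}
	}
	return false
}

func main() {
	allowed := []string{"dashboard-audience", "cli-audience"}
	fmt.Println(hasAllowedAudience([]string{"cli-audience"}, allowed)) // prints "true"
}
```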

View File

@@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
	}
}

// JSON returns the status overview as a JSON string.
func (o *OutputOverview) JSON() (string, error) {
	jsonBytes, err := json.Marshal(o)
	if err != nil {
		return "", fmt.Errorf("json marshal failed")
	}
	return string(jsonBytes), err
}

// YAML returns the status overview as a YAML string.
func (o *OutputOverview) YAML() (string, error) {
	yamlBytes, err := yaml.Marshal(o)
	if err != nil {
		return "", fmt.Errorf("yaml marshal failed")
	}
	return string(yamlBytes), nil
}

// GeneralSummary returns a general summary of the status overview.
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
	var managementConnString string
	if o.ManagementState.Connected {
		managementConnString = "Connected"
		if showURL {
			managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
		}
	} else {
		managementConnString = "Disconnected"
		if o.ManagementState.Error != "" {
			managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
		}
	}

	var signalConnString string
	if o.SignalState.Connected {
		signalConnString = "Connected"
		if showURL {
			signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
		}
	} else {
		signalConnString = "Disconnected"
		if o.SignalState.Error != "" {
			signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
		}
	}

	interfaceTypeString := "Userspace"
	interfaceIP := o.IP
	if o.KernelInterface {
		interfaceTypeString = "Kernel"
	} else if o.IP == "" {
		interfaceTypeString = "N/A"
		interfaceIP = "N/A"
	}

	var relaysString string
	if showRelays {
		for _, relay := range o.Relays.Details {
			available := "Available"
			reason := ""
@@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
			relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
		}
	} else {
		relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
	}

	networks := "-"
	if len(o.Networks) > 0 {
		sort.Strings(o.Networks)
		networks = strings.Join(o.Networks, ", ")
	}

	var dnsServersString string
	if showNameServers {
		for _, nsServerGroup := range o.NSServerGroups {
			enabled := "Available"
			if !nsServerGroup.Enabled {
				enabled = "Unavailable"
@@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
			)
		}
	} else {
		dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
	}

	rosenpassEnabledStatus := "false"
	if o.RosenpassEnabled {
		rosenpassEnabledStatus = "true"
		if o.RosenpassPermissive {
			rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
		}
	}

	lazyConnectionEnabledStatus := "false"
	if o.LazyConnectionEnabled {
		lazyConnectionEnabledStatus = "true"
	}

	sshServerStatus := "Disabled"
	if o.SSHServerState.Enabled {
		sessionCount := len(o.SSHServerState.Sessions)
		if sessionCount > 0 {
			sessionWord := "session"
			if sessionCount > 1 {
@@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
		}
	}

	if showSSHSessions && sessionCount > 0 {
		for _, session := range o.SSHServerState.Sessions {
			var sessionDisplay string
			if session.JWTUsername != "" {
				sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
@@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
		}
	}

	peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)

	goos := runtime.GOOS
	goarch := runtime.GOARCH
@@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
		"Forwarding rules: %d\n"+
			"Peers count: %s\n",
		fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
		o.DaemonVersion,
		version.NetbirdVersion(),
		o.ProfileName,
		managementConnString,
		signalConnString,
		relaysString,
		dnsServersString,
		domain.Domain(o.FQDN).SafeString(),
		interfaceIP,
		interfaceTypeString,
		rosenpassEnabledStatus,
		lazyConnectionEnabledStatus,
		sshServerStatus,
		networks,
		o.NumberOfForwardingRules,
		peersCountString,
	)
	return summary
}

// FullDetailSummary returns a full detailed summary with peer details and events.
func (o *OutputOverview) FullDetailSummary() string {
	parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
	parsedEventsString := parseEvents(o.Events)
	summary := o.GeneralSummary(true, true, true, true)

	return fmt.Sprintf(
		"Peers detail:"+

View File

@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
}

func TestParsingToJSON(t *testing.T) {
	jsonString, _ := overview.JSON()

	//@formatter:off
	expectedJSONString := `
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
}

func TestParsingToYAML(t *testing.T) {
	yaml, _ := overview.YAML()

	expectedYAML :=
		`peers:
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
	lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
	lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)

	detail := overview.FullDetailSummary()

	expectedDetail := fmt.Sprintf(
		`Peers detail:
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
}

func TestParsingToShortVersion(t *testing.T) {
	shortVersion := overview.GeneralSummary(false, false, false, false)

	expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1

View File

@@ -0,0 +1,22 @@
package system
// DiskEncryptionVolume represents encryption status of a single volume.
type DiskEncryptionVolume struct {
Path string
Encrypted bool
}
// DiskEncryptionInfo holds disk encryption detection results.
type DiskEncryptionInfo struct {
Volumes []DiskEncryptionVolume
}
// IsEncrypted returns true if the volume at the given path is encrypted.
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
for _, v := range d.Volumes {
if v.Path == path {
return v.Encrypted
}
}
return false
}

View File

@@ -0,0 +1,35 @@
//go:build darwin && !ios
package system
import (
"context"
"os/exec"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// detectDiskEncryption detects FileVault encryption status on macOS.
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
info := DiskEncryptionInfo{}
cmdCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
cmd := exec.CommandContext(cmdCtx, "fdesetup", "status")
output, err := cmd.Output()
if err != nil {
log.Debugf("execute fdesetup: %v", err)
return info
}
encrypted := strings.Contains(string(output), "FileVault is On")
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
Path: "/",
Encrypted: encrypted,
})
return info
}

View File

@@ -0,0 +1,98 @@
//go:build linux && !android
package system
import (
"bufio"
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
// detectDiskEncryption detects LUKS encryption status on Linux by reading sysfs.
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
info := DiskEncryptionInfo{}
encryptedDevices := findEncryptedDevices()
mountPoints := parseMounts(encryptedDevices)
info.Volumes = mountPoints
return info
}
// findEncryptedDevices scans /sys/block for dm-crypt (LUKS) encrypted devices.
func findEncryptedDevices() map[string]bool {
encryptedDevices := make(map[string]bool)
sysBlock := "/sys/block"
entries, err := os.ReadDir(sysBlock)
if err != nil {
log.Debugf("read /sys/block: %v", err)
return encryptedDevices
}
for _, entry := range entries {
dmUuidPath := filepath.Join(sysBlock, entry.Name(), "dm", "uuid")
data, err := os.ReadFile(dmUuidPath)
if err != nil {
continue
}
uuid := strings.TrimSpace(string(data))
if strings.HasPrefix(uuid, "CRYPT-") {
dmNamePath := filepath.Join(sysBlock, entry.Name(), "dm", "name")
if nameData, err := os.ReadFile(dmNamePath); err == nil {
dmName := strings.TrimSpace(string(nameData))
encryptedDevices["/dev/mapper/"+dmName] = true
}
encryptedDevices["/dev/"+entry.Name()] = true
}
}
return encryptedDevices
}
// parseMounts reads /proc/mounts and maps devices to mount points with encryption status.
func parseMounts(encryptedDevices map[string]bool) []DiskEncryptionVolume {
var volumes []DiskEncryptionVolume
mountsFile, err := os.Open("/proc/mounts")
if err != nil {
log.Debugf("open /proc/mounts: %v", err)
return volumes
}
defer func() {
if err := mountsFile.Close(); err != nil {
log.Debugf("close /proc/mounts: %v", err)
}
}()
scanner := bufio.NewScanner(mountsFile)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if len(fields) < 2 {
continue
}
device, mountPoint := fields[0], fields[1]
// Exact-match lookup suffices: findEncryptedDevices records both
// /dev/<name> and /dev/mapper/<name> keys for each dm-crypt device.
encrypted := encryptedDevices[device]
volumes = append(volumes, DiskEncryptionVolume{
Path: mountPoint,
Encrypted: encrypted,
})
}
return volumes
}

View File

@@ -0,0 +1,10 @@
//go:build android || ios || freebsd || js
package system
import "context"
// detectDiskEncryption is a stub for unsupported platforms.
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
return DiskEncryptionInfo{}
}

View File

@@ -0,0 +1,41 @@
//go:build windows
package system
import (
"context"
"strings"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
)
// Win32EncryptableVolume represents the WMI class for BitLocker status.
type Win32EncryptableVolume struct {
DriveLetter string
ProtectionStatus uint32
}
// detectDiskEncryption detects BitLocker encryption status on Windows via WMI.
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
info := DiskEncryptionInfo{}
var volumes []Win32EncryptableVolume
query := "SELECT DriveLetter, ProtectionStatus FROM Win32_EncryptableVolume"
err := wmi.QueryNamespace(query, &volumes, `root\CIMV2\Security\MicrosoftVolumeEncryption`)
if err != nil {
log.Debugf("query BitLocker status: %v", err)
return info
}
for _, vol := range volumes {
driveLetter := strings.TrimSuffix(vol.DriveLetter, "\\")
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
Path: driveLetter,
Encrypted: vol.ProtectionStatus == 1,
})
}
return info
}

View File

@@ -59,6 +59,7 @@ type Info struct {
 	SystemManufacturer string
 	Environment        Environment
 	Files              []File // for posture checks
+	DiskEncryption     DiskEncryptionInfo

 	RosenpassEnabled    bool
 	RosenpassPermissive bool

View File

@@ -1,6 +1,3 @@
-//go:build android
-// +build android
-
 package system

 import (
@@ -47,6 +44,7 @@ func GetInfo(ctx context.Context) *Info {
 	SystemSerialNumber: serial(),
 	SystemProductName:  productModel(),
 	SystemManufacturer: productManufacturer(),
+	DiskEncryption:     detectDiskEncryption(ctx),
 }
 return gio

View File

@@ -1,5 +1,4 @@
 //go:build !ios
-// +build !ios

 package system

@@ -63,6 +62,7 @@ func GetInfo(ctx context.Context) *Info {
 	SystemProductName:  si.SystemProductName,
 	SystemManufacturer: si.SystemManufacturer,
 	Environment:        si.Environment,
+	DiskEncryption:     detectDiskEncryption(ctx),
 }
 systemHostname, _ := os.Hostname()

View File

@@ -55,6 +55,7 @@ func GetInfo(ctx context.Context) *Info {
 		UIVersion:     extractUserAgent(ctx),
 		KernelVersion: osInfo[1],
 		Environment:   env,
+		DiskEncryption: detectDiskEncryption(ctx),
 	}
 }

View File

@@ -1,6 +1,3 @@
-//go:build ios
-// +build ios
-
 package system

 import (
@@ -22,7 +19,7 @@ func GetInfo(ctx context.Context) *Info {
 	sysName := extractOsName(ctx, "sysName")
 	swVersion := extractOsVersion(ctx, "swVersion")
-	gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
+	gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion, DiskEncryption: detectDiskEncryption(ctx)}
 	gio.Hostname = extractDeviceName(ctx, "hostname")
 	gio.NetbirdVersion = version.NetbirdVersion()
 	gio.UIVersion = extractUserAgent(ctx)

View File

@@ -15,7 +15,7 @@ func UpdateStaticInfoAsync() {
 }

 // GetInfo retrieves system information for WASM environment
-func GetInfo(_ context.Context) *Info {
+func GetInfo(ctx context.Context) *Info {
 	info := &Info{
 		GoOS:   runtime.GOOS,
 		Kernel: runtime.GOARCH,
@@ -25,6 +25,7 @@ func GetInfo(_ context.Context) *Info {
 		Hostname:       "wasm-client",
 		CPUs:           runtime.NumCPU(),
 		NetbirdVersion: version.NetbirdVersion(),
+		DiskEncryption: detectDiskEncryption(ctx),
 	}
 	collectBrowserInfo(info)

View File

@@ -73,6 +73,7 @@ func GetInfo(ctx context.Context) *Info {
 	SystemProductName:  si.SystemProductName,
 	SystemManufacturer: si.SystemManufacturer,
 	Environment:        si.Environment,
+	DiskEncryption:     detectDiskEncryption(ctx),
 }
 return gio

View File

@@ -35,6 +35,7 @@ func GetInfo(ctx context.Context) *Info {
 	SystemProductName:  si.SystemProductName,
 	SystemManufacturer: si.SystemManufacturer,
 	Environment:        si.Environment,
+	DiskEncryption:     detectDiskEncryption(ctx),
 }
 addrs, err := networkAddresses()

View File

@@ -510,7 +510,7 @@ func (s *serviceClient) saveSettings() {
 	// Continue with default behavior if features can't be retrieved
 } else if features != nil && features.DisableUpdateSettings {
 	log.Warn("Configuration updates are disabled by daemon")
-	dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
+	dialog.ShowError(fmt.Errorf("configuration updates are disabled by daemon"), s.wSettings)
 	return
 }
@@ -540,7 +540,7 @@ func (s *serviceClient) saveSettings() {
 func (s *serviceClient) validateSettings() error {
 	if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
 		if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
-			return fmt.Errorf("Invalid Pre-shared Key Value")
+			return fmt.Errorf("invalid pre-shared key value")
 		}
 	}
 	return nil
@@ -549,10 +549,10 @@ func (s *serviceClient) validateSettings() error {
 func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
 	port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
 	if err != nil {
-		return 0, 0, errors.New("Invalid interface port")
+		return 0, 0, errors.New("invalid interface port")
 	}
 	if port < 1 || port > 65535 {
-		return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
+		return 0, 0, errors.New("invalid interface port: out of range 1-65535")
 	}
 	var mtu int64
@@ -560,7 +560,7 @@ func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
 	if mtuText != "" {
 		mtu, err = strconv.ParseInt(mtuText, 10, 64)
 		if err != nil {
-			return 0, 0, errors.New("Invalid MTU value")
+			return 0, 0, errors.New("invalid MTU value")
 		}
 		if mtu < iface.MinMTU || mtu > iface.MaxMTU {
 			return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU)
@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
 	if sshJWTCacheTTLText != "" {
 		sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
 		if err != nil {
-			return nil, errors.New("Invalid SSH JWT Cache TTL value")
+			return nil, errors.New("invalid SSH JWT Cache TTL value")
 		}
 		if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
 			return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)

View File

@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
 	var postUpStatusOutput string
 	if postUpStatus != nil {
 		overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
-		postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
+		postUpStatusOutput = overview.FullDetailSummary()
 	}
 	headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
 	statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
 	var preDownStatusOutput string
 	if preDownStatus != nil {
 		overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
-		preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
+		preDownStatusOutput = overview.FullDetailSummary()
 	}
 	headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
 		time.Now().Format(time.RFC3339), params.duration)
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
 	var statusOutput string
 	if statusResp != nil {
 		overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
-		statusOutput = nbstatus.ParseToFullDetailSummary(overview)
+		statusOutput = overview.FullDetailSummary()
 	}
 	request := &proto.DebugBundleRequest{

View File

@@ -164,7 +164,7 @@ func sendShowWindowSignal(pid int32) error {
 	err = windows.SetEvent(eventHandle)
 	if err != nil {
-		return fmt.Errorf("Error setting event: %w", err)
+		return fmt.Errorf("error setting event: %w", err)
 	}
 	return nil

View File

@@ -9,20 +9,29 @@ import (
 	"time"

 	log "github.com/sirupsen/logrus"
+	"google.golang.org/protobuf/encoding/protojson"

 	netbird "github.com/netbirdio/netbird/client/embed"
+	"github.com/netbirdio/netbird/client/proto"
 	sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
+	nbstatus "github.com/netbirdio/netbird/client/status"
 	"github.com/netbirdio/netbird/client/wasm/internal/http"
 	"github.com/netbirdio/netbird/client/wasm/internal/rdp"
 	"github.com/netbirdio/netbird/client/wasm/internal/ssh"
 	"github.com/netbirdio/netbird/util"
+	"github.com/netbirdio/netbird/version"
 )

 const (
 	clientStartTimeout = 30 * time.Second
 	clientStopTimeout  = 10 * time.Second
+	pingTimeout        = 10 * time.Second

 	defaultLogLevel            = "warn"
 	defaultSSHDetectionTimeout = 20 * time.Second
+
+	icmpEchoRequest = 8
+	icmpCodeEcho    = 0
+	pingBufferSize  = 1500
 )

 func main() {
@@ -113,18 +122,45 @@ func createStopMethod(client *netbird.Client) js.Func {
 	})
 }
// validateSSHArgs validates SSH connection arguments
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
if len(args) < 2 {
return "", 0, "", js.ValueOf("error: requires host and port")
}
if args[0].Type() != js.TypeString {
return "", 0, "", js.ValueOf("host parameter must be a string")
}
if args[1].Type() != js.TypeNumber {
return "", 0, "", js.ValueOf("port parameter must be a number")
}
host = args[0].String()
port = args[1].Int()
username = "root"
if len(args) > 2 {
if args[2].Type() == js.TypeString && args[2].String() != "" {
username = args[2].String()
} else if args[2].Type() != js.TypeString {
return "", 0, "", js.ValueOf("username parameter must be a string")
}
}
return host, port, username, js.Undefined()
}
 // createSSHMethod creates the SSH connection method
 func createSSHMethod(client *netbird.Client) js.Func {
 	return js.FuncOf(func(this js.Value, args []js.Value) any {
-		if len(args) < 2 {
-			return js.ValueOf("error: requires host and port")
-		}
-
-		host := args[0].String()
-		port := args[1].Int()
-		username := "root"
-		if len(args) > 2 && args[2].String() != "" {
-			username = args[2].String()
-		}
+		host, port, username, validationErr := validateSSHArgs(args)
+		if !validationErr.IsUndefined() {
+			if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
+				return validationErr
+			}
+			return createPromise(func(resolve, reject js.Value) {
+				reject.Invoke(validationErr)
+			})
+		}

 		var jwtToken string
@@ -154,6 +190,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
 	})
 }
func performPing(client *netbird.Client, hostname string) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
start := time.Now()
conn, err := client.Dial(ctx, "ping", hostname)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
return
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close ping connection: %v", err)
}
}()
icmpData := make([]byte, 8)
icmpData[0] = icmpEchoRequest
icmpData[1] = icmpCodeEcho
if _, err := conn.Write(icmpData); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
return
}
buf := make([]byte, pingBufferSize)
if _, err := conn.Read(buf); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
return
}
latency := time.Since(start)
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
}
func performPingTCP(client *netbird.Client, hostname string, port int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
address := fmt.Sprintf("%s:%d", hostname, port)
start := time.Now()
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
return
}
latency := time.Since(start)
if err := conn.Close(); err != nil {
log.Debugf("failed to close TCP connection: %v", err)
}
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
}
// createPingMethod creates the ping method
func createPingMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: hostname required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
hostname := args[0].String()
return createPromise(func(resolve, reject js.Value) {
performPing(client, hostname)
resolve.Invoke(js.Undefined())
})
})
}
// createPingTCPMethod creates the pingtcp method
func createPingTCPMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeNumber {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a number"))
})
}
hostname := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
performPingTCP(client, hostname, port)
resolve.Invoke(js.Undefined())
})
})
}
 // createProxyRequestMethod creates the proxyRequest method
 func createProxyRequestMethod(client *netbird.Client) js.Func {
 	return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -162,6 +302,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
 		}
 		request := args[0]
+		if request.Type() != js.TypeObject {
+			return createPromise(func(resolve, reject js.Value) {
+				reject.Invoke(js.ValueOf("request parameter must be an object"))
+			})
+		}

 		return createPromise(func(resolve, reject js.Value) {
 			response, err := http.ProxyRequest(client, request)
@@ -181,11 +326,145 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
 			return js.ValueOf("error: hostname and port required")
 		}
+		if args[0].Type() != js.TypeString {
+			return createPromise(func(resolve, reject js.Value) {
+				reject.Invoke(js.ValueOf("hostname parameter must be a string"))
+			})
+		}
+		if args[1].Type() != js.TypeString {
+			return createPromise(func(resolve, reject js.Value) {
+				reject.Invoke(js.ValueOf("port parameter must be a string"))
+			})
+		}
 		proxy := rdp.NewRDCleanPathProxy(client)
 		return proxy.CreateProxy(args[0].String(), args[1].String())
 	})
 }
// getStatusOverview is a helper to get the status overview
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
fullStatus, err := client.Status()
if err != nil {
return nbstatus.OutputOverview{}, err
}
pbFullStatus := fullStatus.ToProto()
statusResp := &proto.StatusResponse{
DaemonVersion: version.NetbirdVersion(),
FullStatus: pbFullStatus,
}
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
}
// createStatusMethod creates the status method that returns JSON
func createStatusMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonStr, err := overview.JSON()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
resolve.Invoke(jsonObj)
})
})
}
// createStatusSummaryMethod creates the statusSummary method
func createStatusSummaryMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
summary := overview.GeneralSummary(false, false, false, false)
js.Global().Get("console").Call("log", summary)
resolve.Invoke(js.Undefined())
})
})
}
// createStatusDetailMethod creates the statusDetail method
func createStatusDetailMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
detail := overview.FullDetailSummary()
js.Global().Get("console").Call("log", detail)
resolve.Invoke(js.Undefined())
})
})
}
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
syncResp, err := client.GetLatestSyncResponse()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
AllowPartial: true,
}
jsonBytes, err := options.Marshal(syncResp)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
resolve.Invoke(jsonObj)
})
})
}
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
func createSetLogLevelMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: log level required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("log level parameter must be a string"))
})
}
logLevel := args[0].String()
return createPromise(func(resolve, reject js.Value) {
if err := client.SetLogLevel(logLevel); err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
return
}
log.Infof("Log level set to: %s", logLevel)
resolve.Invoke(js.ValueOf(true))
})
})
}
 // createPromise is a helper to create JavaScript promises
 func createPromise(handler func(resolve, reject js.Value)) js.Value {
 	return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
@@ -237,17 +516,24 @@ func createClientObject(client *netbird.Client) js.Value {
 	obj["start"] = createStartMethod(client)
 	obj["stop"] = createStopMethod(client)
+	obj["ping"] = createPingMethod(client)
+	obj["pingtcp"] = createPingTCPMethod(client)
 	obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
 	obj["createSSHConnection"] = createSSHMethod(client)
 	obj["proxyRequest"] = createProxyRequestMethod(client)
 	obj["createRDPProxy"] = createRDPProxyMethod(client)
+	obj["status"] = createStatusMethod(client)
+	obj["statusSummary"] = createStatusSummaryMethod(client)
+	obj["statusDetail"] = createStatusDetailMethod(client)
+	obj["getSyncResponse"] = createGetSyncResponseMethod(client)
+	obj["setLogLevel"] = createSetLogLevelMethod(client)
 	return js.ValueOf(obj)
 }

 // netBirdClientConstructor acts as a JavaScript constructor function
-func netBirdClientConstructor(this js.Value, args []js.Value) any {
-	return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
+func netBirdClientConstructor(_ js.Value, args []js.Value) any {
+	return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
 		resolve := promiseArgs[0]
 		reject := promiseArgs[1]

View File

@@ -47,8 +47,8 @@ type CustomZone struct {
 	Records []SimpleRecord
 	// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
 	SearchDomainDisabled bool
-	// SkipPTRProcess indicates whether a client should process PTR records from custom zones
-	SkipPTRProcess bool
+	// NonAuthoritative marks user-created zones
+	NonAuthoritative bool
 }

 // SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records

go.mod
View File

@@ -1,6 +1,8 @@
 module github.com/netbirdio/netbird

-go 1.24.10
+go 1.25
+
+toolchain go1.25.5

 require (
 	cunicu.li/go-rosenpass v0.4.0
@@ -40,7 +42,7 @@ require (
 	github.com/cilium/ebpf v0.15.0
 	github.com/coder/websocket v1.8.13
 	github.com/coreos/go-iptables v0.7.0
-	github.com/creack/pty v1.1.18
+	github.com/creack/pty v1.1.24
 	github.com/dexidp/dex v0.0.0-00010101000000-000000000000
 	github.com/dexidp/dex/api/v2 v2.4.0
 	github.com/eko/gocache/lib/v4 v4.2.0
@@ -76,12 +78,12 @@ require (
 	github.com/pion/logging v0.2.4
 	github.com/pion/randutil v0.1.0
 	github.com/pion/stun/v2 v2.0.0
-	github.com/pion/stun/v3 v3.0.0
-	github.com/pion/transport/v3 v3.0.7
+	github.com/pion/stun/v3 v3.1.0
+	github.com/pion/transport/v3 v3.1.1
 	github.com/pion/turn/v3 v3.0.1
 	github.com/pkg/sftp v1.13.9
 	github.com/prometheus/client_golang v1.23.2
-	github.com/quic-go/quic-go v0.49.1
+	github.com/quic-go/quic-go v0.55.0
 	github.com/redis/go-redis/v9 v9.7.3
 	github.com/rs/xid v1.3.0
 	github.com/shirou/gopsutil/v3 v3.24.4
@@ -103,7 +105,7 @@ require (
 	go.opentelemetry.io/otel/exporters/prometheus v0.48.0
 	go.opentelemetry.io/otel/metric v1.38.0
 	go.opentelemetry.io/otel/sdk/metric v1.38.0
-	go.uber.org/mock v0.5.0
+	go.uber.org/mock v0.5.2
 	go.uber.org/zap v1.27.0
 	goauthentik.io/api/v3 v3.2023051.3
 	golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
@@ -120,7 +122,7 @@ require (
 	gorm.io/driver/postgres v1.5.7
 	gorm.io/driver/sqlite v1.5.7
 	gorm.io/gorm v1.25.12
-	gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
+	gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c
 )

 require (
@@ -186,12 +188,10 @@ require (
 	github.com/go-logr/stdr v1.2.2 // indirect
 	github.com/go-ole/go-ole v1.3.0 // indirect
 	github.com/go-sql-driver/mysql v1.9.3 // indirect
-	github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
 	github.com/go-text/render v0.2.0 // indirect
 	github.com/go-text/typesetting v0.2.1 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
 	github.com/google/btree v1.1.2 // indirect
-	github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect
 	github.com/google/s2a-go v0.1.9 // indirect
 	github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
 	github.com/googleapis/gax-go/v2 v2.15.0 // indirect
@@ -241,7 +241,7 @@ require (
 	github.com/opencontainers/go-digest v1.0.0 // indirect
 	github.com/opencontainers/image-spec v1.1.0 // indirect
 	github.com/pion/dtls/v2 v2.2.10 // indirect
-	github.com/pion/dtls/v3 v3.0.7 // indirect
+	github.com/pion/dtls/v3 v3.0.9 // indirect
 	github.com/pion/mdns/v2 v2.0.7 // indirect
 	github.com/pion/transport/v2 v2.2.4 // indirect
 	github.com/pion/turn/v4 v4.1.1 // indirect
@@ -263,7 +263,7 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect github.com/wlynxg/anet v0.0.5 // indirect
github.com/yuin/goldmark v1.7.8 // indirect github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect
@@ -285,7 +285,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

go.sum

@@ -101,9 +101,6 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3
 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
 github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
 github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
-github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
-github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
 github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
 github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
 github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
@@ -121,8 +118,8 @@ github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmr
 github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
 github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
 github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
-github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
+github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
-github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
+github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
 github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
 github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -286,7 +283,6 @@ github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09
 github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
 github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
 github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
-github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
 github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -411,8 +407,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
 github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
-github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ=
+github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
-github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
+github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
 github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -448,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
 github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
 github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
 github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
-github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
+github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
-github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
+github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
 github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
 github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
 github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -459,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
 github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
 github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
 github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
-github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
+github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
-github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
+github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
 github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
 github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
 github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
 github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
-github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
+github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
-github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
+github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
 github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
 github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
 github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
@@ -491,8 +487,8 @@ github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9Z
 github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
 github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
 github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
-github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
+github.com/quic-go/quic-go v0.55.0 h1:zccPQIqYCXDt5NmcEabyYvOnomjs8Tlwl7tISjJh9Mk=
-github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
+github.com/quic-go/quic-go v0.55.0/go.mod h1:DR51ilwU1uE164KuWXhinFcKWGlEjzys2l8zUl5Ss1U=
 github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
 github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
 github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
@@ -578,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
 github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
 github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
 github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
-github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
+github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
-github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
+github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -622,8 +618,8 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
 go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
 go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
-go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
+go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
-go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
+go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
 go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
@@ -717,7 +713,6 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -848,5 +843,5 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
 gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
 gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
 gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
-gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
+gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c h1:pfzmXIkkDgydR4ZRP+e1hXywZfYR21FA0Fbk6ptMkiA=
-gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
+gvisor.dev/gvisor v0.0.0-20251031020517-ecfcdd2f171c/go.mod h1:/mc6CfwbOm5KKmqoV7Qx20Q+Ja8+vO4g7FuCdlVoAfQ=

idp/dex/logrus_handler.go (new file)

@@ -0,0 +1,113 @@
package dex

import (
	"context"
	"log/slog"

	"github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/formatter"
)

// LogrusHandler is an slog.Handler that delegates to logrus.
// This allows Dex to use the same log format as the rest of NetBird.
type LogrusHandler struct {
	logger *logrus.Logger
	attrs  []slog.Attr
	groups []string
}

// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter.
func NewLogrusHandler(level slog.Level) *LogrusHandler {
	logger := logrus.New()
	formatter.SetTextFormatter(logger)

	// Map slog level to logrus level
	switch level {
	case slog.LevelDebug:
		logger.SetLevel(logrus.DebugLevel)
	case slog.LevelInfo:
		logger.SetLevel(logrus.InfoLevel)
	case slog.LevelWarn:
		logger.SetLevel(logrus.WarnLevel)
	case slog.LevelError:
		logger.SetLevel(logrus.ErrorLevel)
	default:
		logger.SetLevel(logrus.WarnLevel)
	}

	return &LogrusHandler{logger: logger}
}

// Enabled reports whether the handler handles records at the given level.
func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool {
	switch level {
	case slog.LevelDebug:
		return h.logger.IsLevelEnabled(logrus.DebugLevel)
	case slog.LevelInfo:
		return h.logger.IsLevelEnabled(logrus.InfoLevel)
	case slog.LevelWarn:
		return h.logger.IsLevelEnabled(logrus.WarnLevel)
	case slog.LevelError:
		return h.logger.IsLevelEnabled(logrus.ErrorLevel)
	default:
		return true
	}
}

// Handle handles the Record.
func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error {
	fields := make(logrus.Fields)

	// Add pre-set attributes
	for _, attr := range h.attrs {
		fields[attr.Key] = attr.Value.Any()
	}

	// Add record attributes
	r.Attrs(func(attr slog.Attr) bool {
		fields[attr.Key] = attr.Value.Any()
		return true
	})

	entry := h.logger.WithFields(fields)
	switch r.Level {
	case slog.LevelDebug:
		entry.Debug(r.Message)
	case slog.LevelInfo:
		entry.Info(r.Message)
	case slog.LevelWarn:
		entry.Warn(r.Message)
	case slog.LevelError:
		entry.Error(r.Message)
	default:
		entry.Info(r.Message)
	}

	return nil
}

// WithAttrs returns a new Handler with the given attributes added.
func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
	copy(newAttrs, h.attrs)
	copy(newAttrs[len(h.attrs):], attrs)

	return &LogrusHandler{
		logger: h.logger,
		attrs:  newAttrs,
		groups: h.groups,
	}
}

// WithGroup returns a new Handler with the given group appended to the receiver's groups.
func (h *LogrusHandler) WithGroup(name string) slog.Handler {
	newGroups := make([]string, len(h.groups)+1)
	copy(newGroups, h.groups)
	newGroups[len(h.groups)] = name

	return &LogrusHandler{
		logger: h.logger,
		attrs:  h.attrs,
		groups: newGroups,
	}
}
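The handler above implements Go's `slog.Handler` contract: `Enabled` gates records by level, `Handle` merges the pre-set attrs with the record's attrs, and `WithAttrs`/`WithGroup` return copies so sibling loggers never share backing storage. The sketch below is a stdlib-only illustration of that same delegation pattern using a toy in-memory sink instead of logrus; it is not NetBird code, and `memHandler`/`demo` are invented names for illustration.

```go
package main

import (
	"context"
	"fmt"
	"log/slog"
)

// memHandler is a toy slog.Handler that records formatted lines in memory,
// mirroring the structure of LogrusHandler: level gate, attr merge, and
// copy-on-WithAttrs so derived loggers never share a backing slice.
type memHandler struct {
	min   slog.Level
	attrs []slog.Attr
	out   *[]string
}

func (h *memHandler) Enabled(_ context.Context, l slog.Level) bool { return l >= h.min }

func (h *memHandler) Handle(_ context.Context, r slog.Record) error {
	line := fmt.Sprintf("%s msg=%q", r.Level, r.Message)
	for _, a := range h.attrs { // pre-set attributes first, as in LogrusHandler
		line += fmt.Sprintf(" %s=%v", a.Key, a.Value.Any())
	}
	r.Attrs(func(a slog.Attr) bool { // then the record's own attributes
		line += fmt.Sprintf(" %s=%v", a.Key, a.Value.Any())
		return true
	})
	*h.out = append(*h.out, line)
	return nil
}

func (h *memHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	na := make([]slog.Attr, 0, len(h.attrs)+len(attrs))
	na = append(na, h.attrs...)
	na = append(na, attrs...)
	return &memHandler{min: h.min, attrs: na, out: h.out}
}

// Groups are ignored in this sketch; LogrusHandler stores them but does not
// apply them in Handle either.
func (h *memHandler) WithGroup(string) slog.Handler { return h }

// demo routes slog records through the handler and returns what was recorded.
func demo() []string {
	var lines []string
	logger := slog.New(&memHandler{min: slog.LevelWarn, out: &lines})
	logger.Info("dropped") // below the WARN threshold, filtered by Enabled
	logger.With("component", "dex").Warn("kept", "attempts", 3)
	return lines
}

func main() {
	fmt.Println(demo())
}
```

Because `WithAttrs` copies into a fresh slice, `logger.With(...)` can be called concurrently from many goroutines without two derived loggers appending into the same array — the same reason `LogrusHandler.WithAttrs` copies rather than appends in place.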


@@ -130,7 +130,21 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
 // NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
 func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
-	logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
+	// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
+	logLevel := slog.LevelWarn
+	if yamlConfig.Logger.Level != "" {
+		switch strings.ToLower(yamlConfig.Logger.Level) {
+		case "debug":
+			logLevel = slog.LevelDebug
+		case "info":
+			logLevel = slog.LevelInfo
+		case "warn", "warning":
+			logLevel = slog.LevelWarn
+		case "error":
+			logLevel = slog.LevelError
+		}
+	}
+	logger := slog.New(NewLogrusHandler(logLevel))
 	stor, err := yamlConfig.Storage.OpenStorage(logger)
 	if err != nil {
@@ -778,20 +792,24 @@ func (p *Provider) resolveRedirectURI(redirectURI string) string {
 // buildOIDCConnectorConfig creates config for OIDC-based connectors
 func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte, error) {
 	oidcConfig := map[string]interface{}{
 		"issuer":       cfg.Issuer,
 		"clientID":     cfg.ClientID,
 		"clientSecret": cfg.ClientSecret,
 		"redirectURI":  redirectURI,
 		"scopes":       []string{"openid", "profile", "email"},
+		"insecureEnableGroups": true,
+		// some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
+		"insecureSkipEmailVerified": true,
 	}
 	switch cfg.Type {
 	case "zitadel":
 		oidcConfig["getUserInfo"] = true
 	case "entra":
-		oidcConfig["insecureSkipEmailVerified"] = true
 		oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
 	case "okta":
-		oidcConfig["insecureSkipEmailVerified"] = true
+		oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
+	case "pocketid":
+		oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
 	}
 	return encodeConnectorConfig(oidcConfig)
 }

File diff suppressed because it is too large


@@ -64,7 +64,7 @@ var (
 		config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
 	}
-	tlsEnabled := false
+	var tlsEnabled bool
 	if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {
 		tlsEnabled = true
 	}
@@ -143,7 +143,7 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Confi
 	applyCommandLineOverrides(loadedConfig)
 	// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
-	err := applyEmbeddedIdPConfig(loadedConfig)
+	err := applyEmbeddedIdPConfig(ctx, loadedConfig)
 	if err != nil {
 		return nil, err
 	}
@@ -177,7 +177,7 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
 // applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
 // This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
-func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
+func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
 	if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
 		return nil
 	}
@@ -190,10 +190,8 @@
 	// Enable user deletion from IDP by default if EmbeddedIdP is enabled
 	userDeleteFromIDPEnabled = true
-	// Ensure HttpConfig exists
-	if cfg.HttpConfig == nil {
-		cfg.HttpConfig = &nbconfig.HttpServerConfig{}
-	}
+	// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
+	cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
 	// Set storage defaults based on Datadir
 	if cfg.EmbeddedIdP.Storage.Type == "" {
@@ -205,40 +203,22 @@
 	issuer := cfg.EmbeddedIdP.Issuer
-	// Set AuthIssuer from EmbeddedIdP issuer
-	if cfg.HttpConfig.AuthIssuer == "" {
-		cfg.HttpConfig.AuthIssuer = issuer
-	}
-	// Set AuthAudience to the dashboard client ID
-	if cfg.HttpConfig.AuthAudience == "" {
-		cfg.HttpConfig.AuthAudience = "netbird-dashboard"
-	}
-	// Set CLIAuthAudience to the client app client ID
-	if cfg.HttpConfig.CLIAuthAudience == "" {
-		cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
-	}
-	// Set AuthUserIDClaim to "sub" (standard OIDC claim)
-	if cfg.HttpConfig.AuthUserIDClaim == "" {
-		cfg.HttpConfig.AuthUserIDClaim = "sub"
-	}
-	// Set AuthKeysLocation to the JWKS endpoint
-	if cfg.HttpConfig.AuthKeysLocation == "" {
-		cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
-	}
-	// Set OIDCConfigEndpoint to the discovery endpoint
-	if cfg.HttpConfig.OIDCConfigEndpoint == "" {
-		cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
-	}
-	// Copy SignKeyRefreshEnabled from EmbeddedIdP config
-	if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
-		cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
-	}
+	if cfg.HttpConfig != nil {
+		log.WithContext(ctx).Warnf("overriding HttpConfig with EmbeddedIdP config. " +
+			"HttpConfig is ignored when EmbeddedIdP is enabled. Please remove HttpConfig section from the config file")
+	} else {
+		// Ensure HttpConfig exists. We need it for backwards compatibility with the old config format.
+		cfg.HttpConfig = &nbconfig.HttpServerConfig{}
+	}
+	// Set HttpConfig values from EmbeddedIdP
+	cfg.HttpConfig.AuthIssuer = issuer
+	cfg.HttpConfig.AuthAudience = "netbird-dashboard"
+	cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
+	cfg.HttpConfig.AuthUserIDClaim = "sub"
+	cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
+	cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
+	cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
 	return nil
 }
@@ -246,7 +226,12 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
 // applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
 func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
 	oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
-	if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
+	if oidcEndpoint == "" {
+		return nil
+	}
+	if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
+		// skip OIDC config fetching if EmbeddedIdP is enabled as it is unnecessary given it is embedded
 		return nil
 	}
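The rewritten `applyEmbeddedIdPConfig` above derives all JWT-validation settings from a single issuer URL instead of filling each `HttpConfig` field conditionally. A minimal sketch of that derivation, with illustrative type and function names (not NetBird's actual types):

```go
package main

import "fmt"

// httpAuthConfig holds the subset of fields that applyEmbeddedIdPConfig
// derives from the embedded IdP issuer URL. Names are illustrative.
type httpAuthConfig struct {
	AuthIssuer         string
	AuthAudience       string
	CLIAuthAudience    string
	AuthUserIDClaim    string
	AuthKeysLocation   string
	OIDCConfigEndpoint string
}

// deriveFromIssuer computes the auth settings from the issuer alone: the
// JWKS endpoint and the OIDC discovery document both live under the issuer,
// and the audiences/claim are fixed for the embedded IdP's own clients.
func deriveFromIssuer(issuer string) httpAuthConfig {
	return httpAuthConfig{
		AuthIssuer:         issuer,
		AuthAudience:       "netbird-dashboard",
		CLIAuthAudience:    "netbird-cli",
		AuthUserIDClaim:    "sub", // standard OIDC subject claim
		AuthKeysLocation:   issuer + "/keys",
		OIDCConfigEndpoint: issuer + "/.well-known/openid-configuration",
	}
}

func main() {
	c := deriveFromIssuer("https://netbird.example.com/idp")
	fmt.Println(c.AuthKeysLocation)
	fmt.Println(c.OIDCConfigEndpoint)
}
```

Deriving everything from the issuer is what makes the warn-and-override behavior in the diff safe: any user-supplied `HttpConfig` values could only disagree with what the embedded IdP actually serves.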


@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
	dnsCache := &cache.DNSConfigCache{}
	dnsDomain := c.GetDNSDomain(account.Settings)
-	customZone := account.GetPeersCustomZone(ctx, dnsDomain)
+	peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
	resourcePolicies := account.GetResourcePoliciesMap()
	routers := account.GetResourceRoutersMap()
	groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
	dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)

+	accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
+	if err != nil {
+		log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
+		return fmt.Errorf("failed to get account zones: %v", err)
+	}

	for _, peer := range account.Peers {
		if !c.peersUpdateManager.HasChannel(peer.ID) {
			log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
	var remotePeerNetworkMap *types.NetworkMap
	if c.experimentalNetworkMap(accountID) {
-		remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
+		remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
	} else {
-		remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
+		remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
	}

	c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
	dnsCache := &cache.DNSConfigCache{}
	dnsDomain := c.GetDNSDomain(account.Settings)
-	customZone := account.GetPeersCustomZone(ctx, dnsDomain)
+	peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
	resourcePolicies := account.GetResourcePoliciesMap()
	routers := account.GetResourceRoutersMap()
	groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
		return err
	}

+	accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
+	if err != nil {
+		log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
+		return err
+	}

	var remotePeerNetworkMap *types.NetworkMap
	if c.experimentalNetworkMap(accountId) {
-		remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
+		remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
	} else {
-		remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
+		remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
	}

	proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
	}
	log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))

-	customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
+	accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
+	if err != nil {
+		log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
+		return nil, nil, nil, 0, err
+	}
+
+	dnsDomain := c.GetDNSDomain(account.Settings)
+	peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)

	proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
	if err != nil {
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
	var networkMap *types.NetworkMap
	if c.experimentalNetworkMap(accountID) {
-		networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
+		networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
	} else {
		resourcePolicies := account.GetResourcePoliciesMap()
		routers := account.GetResourceRoutersMap()
-		networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
+		networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
	}

	proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
	accountId string,
	peerId string,
	validatedPeers map[string]struct{},
-	customZone nbdns.CustomZone,
+	peersCustomZone nbdns.CustomZone,
+	accountZones []*zones.Zone,
	metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
	account := c.getAccountFromHolderOrInit(ctx, accountId)
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
		}
	}

-	return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
+	return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}

func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
	if err != nil {
		return nil, err
	}

-	customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
+	accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
+	if err != nil {
+		log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
+		return nil, err
+	}
+
+	dnsDomain := c.GetDNSDomain(account.Settings)
+	peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)

	proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
	if err != nil {
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
	var networkMap *types.NetworkMap
	if c.experimentalNetworkMap(peer.AccountID) {
-		networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
+		networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
	} else {
		resourcePolicies := account.GetResourcePoliciesMap()
		routers := account.GetResourceRoutersMap()
-		networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
+		networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
	}

	proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]


@@ -3,6 +3,7 @@ package controller
import (
	"context"
+	"github.com/netbirdio/netbird/management/internals/modules/zones"
	"github.com/netbirdio/netbird/management/server/peer"
	"github.com/netbirdio/netbird/management/server/store"
	"github.com/netbirdio/netbird/management/server/types"
@@ -14,6 +15,7 @@ type Repository interface {
	GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
	GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
	GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
+	GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
}

type repository struct {
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
	return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
+
+func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
+	return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
+}


@@ -0,0 +1,13 @@
+package zones
+
+import (
+	"context"
+)
+
+type Manager interface {
+	GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
+	GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
+	CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
+	UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
+	DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
+}

Some files were not shown because too many files have changed in this diff.