Compare commits

..

2 Commits

Author SHA1 Message Date
crn4
97ad3307dd Merge branch 'main' into log/conn-disconn 2026-01-23 19:59:25 +01:00
crn4
f6cc27d675 add some logs for conn/disconn status change 2026-01-23 18:01:26 +01:00
405 changed files with 2689 additions and 57249 deletions

View File

@@ -1,6 +0,0 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -46,5 +46,6 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -97,16 +97,6 @@ jobs:
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
- name: Build combined
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 go build .
- name: Build combined 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: combined
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
test:
name: "Client / Unit"
needs: [build-cache]
@@ -154,7 +144,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -214,7 +204,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay:
@@ -271,53 +261,6 @@ jobs:
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -19,8 +19,8 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
skip: go.mod,go.sum,**/proxy/web/**
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum
golangci:
strategy:
fail-fast: false

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.1"
SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -160,7 +160,7 @@ jobs:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ghcr.io
@@ -176,7 +176,6 @@ jobs:
- name: Generate windows syso arm64
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
@@ -186,19 +185,6 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: Tag and push PR images (amd64 only)
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
run: |
PR_TAG="pr-${{ github.event.pull_request.number }}"
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
IMG_NAME="${SRC%%:*}"
DST="${IMG_NAME}:${PR_TAG}"
echo "Tagging ${SRC} -> ${DST}"
docker tag "$SRC" "$DST"
docker push "$DST"
done
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
.run
*.iml
dist/
!proxy/web/dist/
bin/
.env
conf.json

View File

@@ -106,26 +106,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -140,20 +120,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-proxy
dir: proxy/cmd/proxy
env: [CGO_ENABLED=0]
binary: netbird-proxy
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -554,104 +520,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -730,18 +598,6 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -819,43 +675,6 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -1,4 +1,4 @@
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
This BSD3Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
BSD 3-Clause License

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features

View File

@@ -1,19 +1,10 @@
package android
import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
)
// EnvList wraps a Go map for export to Java

View File

@@ -3,7 +3,15 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -76,21 +84,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -108,17 +129,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err, _ = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
@@ -137,41 +160,49 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
var needsLogin bool
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx)
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err, _ = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
go urlOpener.OnLoginSuccess()
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
}
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
if err != nil {
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
return nil, err
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
@@ -181,10 +212,22 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -7,6 +7,7 @@ import (
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -276,15 +277,18 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false
needsLogin, err := authClient.IsLoginRequired(ctx)
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("check login required: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
@@ -296,9 +300,23 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}
err, _ = authClient.Login(ctx, setupKey, jwtToken)
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
return nil
@@ -326,7 +344,11 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
@@ -31,14 +30,6 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -77,10 +68,6 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
}
// validateCredentials checks that exactly one credential type is provided
@@ -149,8 +136,6 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -170,7 +155,6 @@ func New(opts Options) (*Client, error) {
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
recorder: peer.NewRecorder(config.ManagementURL.String()),
}, nil
}
@@ -192,17 +176,13 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
@@ -355,9 +335,14 @@ func (c *Client) NewHTTPClient() *http.Client {
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
@@ -365,7 +350,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
}
}
return c.recorder.GetFullStatus(), nil
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.

View File

@@ -83,10 +83,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChain(); err != nil {
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -181,10 +177,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.cleanupNoTrackChain(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err))
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
@@ -285,125 +277,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
tableRaw = "raw"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
// Egress rules: match outgoing loopback UDP packets
outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil {
return fmt.Errorf("add output sport notrack rule: %w", err)
}
outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil {
return fmt.Errorf("add output dport notrack rule: %w", err)
}
// Ingress rules: match incoming loopback UDP packets
preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil {
return fmt.Errorf("add prerouting wg notrack rule: %w", err)
}
preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"}
if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil {
return fmt.Errorf("add prerouting proxy notrack rule: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)
}
if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("create chain: %w", err)
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add output jump rule: %w", err)
}
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
log.Debugf("delete output jump rule: %v", delErr)
}
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
log.Debugf("delete orphan chain: %v", delErr)
}
return fmt.Errorf("add prerouting jump rule: %w", err)
}
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
return nil
}
jumpRule := []string{"-j", chainNameRaw}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
return fmt.Errorf("remove output jump rule: %w", err)
}
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
return fmt.Errorf("remove prerouting jump rule: %w", err)
}
if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil {
return fmt.Errorf("clear and delete chain: %w", err)
}
return nil
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -168,10 +168,6 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -49,10 +48,8 @@ type Manager struct {
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
router *router
aclManager *AclManager
}
// Create nftables firewall manager
@@ -94,10 +91,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return fmt.Errorf("acl manager init: %w", err)
}
if err := m.initNoTrackChains(workTable); err != nil {
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
@@ -295,15 +288,7 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.Flush(); err != nil {
return err
}
if err := m.refreshNoTrackChains(); err != nil {
log.Errorf("failed to refresh notrack chains: %v", err)
}
return nil
return m.aclManager.Flush()
}
// AddDNATRule adds a DNAT rule
@@ -346,176 +331,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"
)
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which
// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark).
//
// Traffic flows that need NOTRACK:
//
// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite)
// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort
// Matched by: sport=wgPort
//
// 2. Egress: Proxy -> WireGuard (via raw socket)
// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 3. Ingress: Packets to WireGuard
// dst=127.0.0.1:wgPort
// Matched by: dport=wgPort
//
// 4. Ingress: Packets to proxy (after eBPF rewrite)
// dst=127.0.0.1:proxyPort
// Matched by: dport=proxyPort
//
// Rules are cleaned up when the firewall manager is closed.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil {
return fmt.Errorf("notrack chains not initialized")
}
proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort)
wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort)
loopback := []byte{127, 0, 0, 1}
// Egress rules: match outgoing loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackOutputChain.Table,
Chain: m.notrackOutputChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
// Ingress rules: match incoming loopback UDP packets
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort
&expr.Counter{},
&expr.Notrack{},
},
})
m.rConn.AddRule(&nftables.Rule{
Table: m.notrackPreroutingChain.Table,
Chain: m.notrackPreroutingChain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}},
&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort
&expr.Counter{},
&expr.Notrack{},
},
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush notrack rules: %w", err)
}
log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort)
return nil
}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityRaw,
})
m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawPrerouting,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRaw,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush chain creation: %w", err)
}
return nil
}
func (m *Manager) refreshNoTrackChains() error {
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
tableName := getTableName()
for _, c := range chains {
if c.Table.Name != tableName {
continue
}
switch c.Name {
case chainNameRawOutput:
m.notrackOutputChain = c
case chainNameRawPrerouting:
m.notrackPreroutingChain = c
}
}
return nil
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -483,12 +483,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
}
if nftRule.Handle == 0 {
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -665,32 +660,13 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *router) rollbackRules(pair firewall.RouterPair) {
keys := []string{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -952,30 +928,18 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
}
return nil
@@ -1365,89 +1329,65 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
return fmt.Errorf("remove legacy routing rule: %w", err)
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
}
return nil
}
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
log.Debugf("prerouting rule %s not found", ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
return fmt.Errorf("list rules: %w", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
newRules[string(rule.UserData)] = rule
r.rules[string(rule.UserData)] = rule
}
}
}
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
return nil
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1689,34 +1629,20 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
@@ -1831,25 +1757,16 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -18,7 +18,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
)
const (
@@ -720,137 +719,3 @@ func deleteWorkTable() {
}
}
}
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}

View File

@@ -3,6 +3,12 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -11,7 +17,33 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)

View File

@@ -1,9 +1,12 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@@ -23,7 +26,33 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() {
return nil

View File

@@ -115,17 +115,6 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// IsSupersededBy returns true if this connection should be replaced by a new one
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
// connections are superseded by a pure SYN (a new connection attempt for the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
if t.tombstone.Load() {
return true
}
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
@@ -180,7 +169,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists && !conn.IsSupersededBy(flags) {
if exists {
t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
}
@@ -252,7 +241,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsSupersededBy(flags) {
if !exists || conn.IsTombstone() {
return false
}

View File

@@ -485,261 +485,6 @@ func TestTCPAbnormalSequences(t *testing.T) {
})
}
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond

View File

@@ -1,7 +1,6 @@
package uspfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -13,13 +12,11 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -27,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -93,7 +89,6 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -234,7 +229,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
@@ -486,15 +480,11 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: string(ruleKey),
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -509,7 +499,6 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
@@ -526,20 +515,15 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
return r.id == ruleID
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
return fmt.Errorf("route rule not found: %s", ruleID)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
@@ -586,48 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {

View File

@@ -1,376 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
// filtering rule twice returns the same rule ID (idempotent behavior).
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{
netip.MustParsePrefix("100.64.1.0/24"),
netip.MustParsePrefix("100.64.2.0/24"),
}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
// different parameters get distinct IDs.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
// rule during a network map update does not disrupt existing traffic.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.True(t, passAfter,
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule1)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule2)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule3)
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
// Verify the rule blocks traffic to the WG network
srcIP := netip.MustParseAddr("10.0.0.1")
dstIP := netip.MustParseAddr("100.64.0.50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
// EnableRouting multiple times (as happens on each route update) does not
// accumulate duplicate block rules in the routeRules slice.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call EnableRouting multiple times (simulating repeated route updates)
for i := 0; i < 5; i++ {
require.NoError(t, manager.EnableRouting())
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount,
"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
// rule multiple times does not create duplicate entries.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
// after adding it multiple times works correctly.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.False(t, pass, "Traffic should not pass after rule deletion")
}
func setupTestManager(t *testing.T) *Manager {
t.Helper()
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
return manager
}

View File

@@ -263,158 +263,6 @@ func TestAddUDPPacketHook(t *testing.T) {
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
// to the deny map and can be cleanly deleted without leaving orphans.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
// peer rules (simulating network map updates) does not leak rules in the maps.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
// IP are stored in separate maps and don't interfere with each other.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },

View File

@@ -5,8 +5,6 @@ import (
"context"
"fmt"
"io"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -18,18 +16,9 @@ const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
defaultLogChanSize = 1000
logChannelSize = 1000
)
func getLogChannelSize() int {
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultLogChanSize
}
type Level uint32
const (
@@ -80,7 +69,7 @@ type Logger struct {
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, getLogChannelSize()),
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {

View File

@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {

View File

@@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
host, portStr, err := net.SplitHostPort(val)
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue

View File

@@ -29,9 +29,8 @@ type PacketFilter interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
closeOnce sync.Once
filter PacketFilter
mutex sync.RWMutex
}
// newDeviceFilter constructor function
@@ -41,20 +40,6 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
}
}
// Close closes the underlying tun device exactly once.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
var err error
d.closeOnce.Do(func() {
err = d.Device.Close()
})
if err != nil {
return err
}
return nil
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -82,9 +82,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
_ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
@@ -51,7 +50,6 @@ func ValidateMTU(mtu uint16) error {
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
Free() error
}
@@ -82,12 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetProxyPort returns the proxy port used by the WireGuard proxy.
// Returns 0 if no proxy port is used (e.g., for userspace WireGuard).
func (w *WGIface) GetProxyPort() uint16 {
return w.wgProxyFactory.GetProxyPort()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
@@ -229,10 +221,6 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if nbnetstack.IsEnabled() {
return errors.FormatErrorOrNil(result)
}
if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
if err := w.Destroy(); err != nil {

View File

@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
}()
return t.tundev, tunNet, nil
return nsTunDev, tunNet, nil
}
func (t *NetStackTun) Close() error {

View File

@@ -114,21 +114,34 @@ func (p *ProxyBind) Pause() {
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to start package redirection: %v", err)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = ep
ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to convert endpoint address: %v", err)
} else {
p.wgCurrentUsed = ep
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, errors.New("nil address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
@@ -212,16 +225,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
}
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, fmt.Errorf("invalid address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}

View File

@@ -8,6 +8,8 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
@@ -24,10 +26,13 @@ const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
proxyPort int
mtu uint16
ebpfManager ebpfMgr.Manager
@@ -35,8 +40,7 @@ type WGEBPFProxy struct {
turnConnMutex sync.Mutex
lastUsedPort uint16
rawConnIPv4 net.PacketConn
rawConnIPv6 net.PacketConn
rawConn net.PacketConn
conn transport.UDPConn
ctx context.Context
@@ -58,39 +62,23 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy {
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
proxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
p.proxyPort = proxyPort
// Prepare IPv4 raw socket (required)
p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4()
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
return err
}
// Prepare IPv6 raw socket (optional)
p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6()
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil {
log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err)
return err
}
err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort)
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
if err != nil {
if closeErr := p.rawConnIPv4.Close(); closeErr != nil {
log.Warnf("failed to close IPv4 raw socket: %v", closeErr)
}
if p.rawConnIPv6 != nil {
if closeErr := p.rawConnIPv6.Close(); closeErr != nil {
log.Warnf("failed to close IPv6 raw socket: %v", closeErr)
}
}
return err
}
addr := net.UDPAddr{
Port: proxyPort,
Port: wgPorxyPort,
IP: net.ParseIP(loopbackAddr),
}
@@ -106,7 +94,7 @@ func (p *WGEBPFProxy) Listen() error {
p.conn = conn
go p.proxyToRemote()
log.Infof("local wg proxy listening on: %d", proxyPort)
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
return nil
}
@@ -147,25 +135,12 @@ func (p *WGEBPFProxy) Free() error {
result = multierror.Append(result, err)
}
if p.rawConnIPv4 != nil {
if err := p.rawConnIPv4.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if p.rawConnIPv6 != nil {
if err := p.rawConnIPv6.Close(); err != nil {
result = multierror.Append(result, err)
}
if err := p.rawConn.Close(); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
// GetProxyPort returns the proxy listening port.
func (p *WGEBPFProxy) GetProxyPort() uint16 {
return uint16(p.proxyPort)
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {
@@ -241,3 +216,34 @@ generatePort:
}
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,89 +10,12 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy
@@ -101,10 +24,8 @@ type ProxyWrapper struct {
ctx context.Context
cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr
headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
paused bool
pausedCond *sync.Cond
@@ -120,32 +41,15 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr
p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -164,8 +68,7 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock()
p.paused = false
p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted {
p.isStarted = true
@@ -188,32 +91,12 @@ func (p *ProxyWrapper) Pause() {
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
if endpoint == nil || endpoint.IP == nil {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
if endpoint != nil && endpoint.IP != nil {
p.wgEndpointCurrentUsedAddr = endpoint
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
@@ -255,7 +138,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait()
}
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedCond.L.Unlock()
if err != nil {
@@ -281,29 +164,3 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
return n, nil
}
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -54,14 +54,6 @@ func (w *KernelFactory) GetProxy() Proxy {
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active.
func (w *KernelFactory) GetProxyPort() uint16 {
if w.ebpfProxy == nil {
return 0
}
return w.ebpfProxy.GetProxyPort()
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil

View File

@@ -24,11 +24,6 @@ func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu)
}
// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port.
func (w *USPFactory) GetProxyPort() uint16 {
return 0
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -8,87 +8,43 @@ import (
"os"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/client/net"
)
// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets
func PrepareSenderRawSocketIPv4() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET, true)
}
// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets
func PrepareSenderRawSocketIPv6() (net.PacketConn, error) {
return prepareSenderRawSocket(syscall.AF_INET6, false)
}
func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) {
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the header include option on the socket to tell the kernel that headers are included in the packet.
// For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers.
if isIPv4 {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
} else {
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
if closeErr := syscall.Close(fd); closeErr != nil {
log.Warnf("failed to close raw socket fd: %v", closeErr)
}
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file: %v", closeErr)
}
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
// Close the original file to release the FD (net.FilePacketConn duplicates it)
if closeErr := file.Close(); closeErr != nil {
log.Warnf("failed to close file after creating packet conn: %v", closeErr)
}
return packetConn, nil
}

View File

@@ -1,353 +0,0 @@
//go:build linux && !android
package wgproxy
import (
"context"
"net"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs
// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore
func compareUDPAddr(addr1, addr2 net.Addr) bool {
udpAddr1, ok1 := addr1.(*net.UDPAddr)
udpAddr2, ok2 := addr2.(*net.UDPAddr)
if !ok1 || !ok2 {
return addr1.String() == addr2.String()
}
// Compare IP and Port, ignoring zone
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses
func TestRedirectAs_UDP_IPv6(t *testing.T) {
wgPort := 51853
proxy := udp.NewWGUDPProxy(wgPort, 1280)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// testRedirectAs is a helper function that tests the RedirectAs functionality
// It verifies that:
// 1. Initial traffic from relay connection works
// 2. After calling RedirectAs, packets appear to come from the p2p endpoint
// 3. Multiple packets are correctly redirected with the new source address
func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) {
t.Helper()
ctx := context.Background()
// Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types
// In reality, WireGuard binds to a port and receives from both IPv4 and IPv6
wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv4 WireGuard listener: %v", err)
}
defer wgListener4.Close()
wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{
IP: net.ParseIP("::1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create IPv6 WireGuard listener: %v", err)
}
defer wgListener6.Close()
// Determine which listener to use based on the NetBird address IP version
// (this is where initial traffic will come from before RedirectAs is called)
var wgListener *net.UDPConn
if p2pEndpoint.IP.To4() == nil {
wgListener = wgListener6
} else {
wgListener = wgListener4
}
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0, // Random port
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
// Add TURN connection to proxy
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
// Start the proxy
proxy.Work()
// Phase 1: Test initial relay traffic
msgFromRelay := []byte("hello from relay")
if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write to relay server: %v", err)
}
// Set read deadline to avoid hanging
if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
buf := make([]byte, 1024)
n, _, err := wgListener4.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read from WireGuard listener: %v", err)
}
if n != len(msgFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n)
}
if string(buf[:n]) != string(msgFromRelay) {
t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n])
}
// Phase 2: Redirect to p2p endpoint
proxy.RedirectAs(p2pEndpoint)
// Give the proxy a moment to process the redirect
time.Sleep(100 * time.Millisecond)
// Phase 3: Test redirected traffic
redirectedMessages := [][]byte{
[]byte("redirected message 1"),
[]byte("redirected message 2"),
[]byte("redirected message 3"),
}
for i, msg := range redirectedMessages {
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write redirected message %d: %v", i+1, err)
}
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read redirected message %d: %v", i+1, err)
}
// Verify message content
if string(buf[:n]) != string(msg) {
t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n])
}
// Verify source address matches p2p endpoint (this is the key test)
// Use compareUDPAddr to ignore IPv6 zone IDs
if !compareUDPAddr(srcAddr, p2pEndpoint) {
t.Errorf("message %d: expected source address %s, got %s",
i+1, p2pEndpoint.String(), srcAddr.String())
}
}
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
ctx := context.Background()
// Create WireGuard listener
wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: wgPort,
})
if err != nil {
t.Fatalf("failed to create WireGuard listener: %v", err)
}
defer wgListener.Close()
// Create relay server and connection
relayServer, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
})
if err != nil {
t.Fatalf("failed to create relay server: %v", err)
}
defer relayServer.Close()
relayConn, err := net.Dial("udp", relayServer.LocalAddr().String())
if err != nil {
t.Fatalf("failed to create relay connection: %v", err)
}
defer relayConn.Close()
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil {
t.Fatalf("failed to add TURN connection: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("failed to close proxy connection: %v", err)
}
}()
proxy.Work()
// Test switching between multiple endpoints - using addresses in local subnet
endpoints := []*net.UDPAddr{
{IP: net.ParseIP("192.168.0.100"), Port: 51820},
{IP: net.ParseIP("192.168.0.101"), Port: 51821},
{IP: net.ParseIP("192.168.0.102"), Port: 51822},
}
for i, endpoint := range endpoints {
proxy.RedirectAs(endpoint)
time.Sleep(100 * time.Millisecond)
msg := []byte("test message")
if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil {
t.Fatalf("failed to write message for endpoint %d: %v", i, err)
}
buf := make([]byte, 1024)
if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("failed to set read deadline: %v", err)
}
n, srcAddr, err := wgListener.ReadFrom(buf)
if err != nil {
t.Fatalf("failed to read message for endpoint %d: %v", i, err)
}
if string(buf[:n]) != string(msg) {
t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n])
}
if !compareUDPAddr(srcAddr, endpoint) {
t.Errorf("endpoint %d: expected source %s, got %s",
i, endpoint.String(), srcAddr.String())
}
}
}

View File

@@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {

View File

@@ -19,56 +19,37 @@ var (
FixLengths: true,
}
localHostNetIPAddrV4 = &net.IPAddr{
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
localHostNetIPAddrV6 = &net.IPAddr{
IP: net.ParseIP("::1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
localHostAddr *net.IPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
// Create only the raw socket for the address family we need
var rawSocket net.PacketConn
var err error
var localHostAddr *net.IPAddr
if srcAddr.IP.To4() != nil {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4()
localHostAddr = localHostNetIPAddrV4
} else {
rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6()
localHostAddr = localHostNetIPAddrV6
}
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
if closeErr := rawSocket.Close(); closeErr != nil {
log.Warnf("failed to close raw socket: %v", closeErr)
}
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
@@ -91,7 +72,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
@@ -99,40 +80,19 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
// Check if source IP is IPv4 or IPv6
if srcAddr.IP.To4() != nil {
// IPv4
ipv4 := &layers.IPv4{
DstIP: localHostNetIPAddrV4.IP,
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
} else {
// IPv6
ipv6 := &layers.IPv6{
DstIP: localHostNetIPAddrV6.IP,
SrcIP: srcAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(networkLayer)
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}

View File

@@ -189,212 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// Apply the same rules 5 times (simulating repeated network map updates)
for i := 0; i < 5; i++ {
acl.ApplyFiltering(networkMap, false)
}
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
assert.Equal(t, 3, len(acl.peerRulesPairs),
"Should have exactly 3 rule pairs after 5 identical updates")
}
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: add deny and accept rules
networkMap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap1, false)
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
// Second update: remove the deny rule, keep only accept
networkMap2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap2, false)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Should have 1 rule after removing deny rule")
// Third update: remove all rules
networkMap3 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{},
FirewallRulesIsEmpty: true,
}
acl.ApplyFiltering(networkMap3, false)
assert.Equal(t, 0, len(acl.peerRulesPairs),
"Should have 0 rules after removing all rules")
}
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
// accept to deny (or vice versa), the old rule is properly removed and the new
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: accept rule
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Second update: change to deny (same IP/port/proto, different action)
networkMap.FirewallRules = []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
}
acl.ApplyFiltering(networkMap, false)
// Should still have exactly 1 rule (the old accept removed, new deny added)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Changing action should result in exactly 1 rule, not 2")
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -1,499 +0,0 @@
package auth
import (
"context"
"net/url"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Auth manages authentication operations with the management server
// It maintains a long-lived connection and automatically handles reconnection with backoff
type Auth struct {
mutex sync.RWMutex
client *mgm.GrpcClient
config *profilemanager.Config
privateKey wgtypes.Key
mgmURL *url.URL
mgmTLSEnabled bool
}
// NewAuth creates a new Auth instance that manages authentication flows
// It establishes a connection to the management server that will be reused for all operations
// The connection is automatically recreated with backoff if it becomes disconnected
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
// Validate WireGuard private key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
return nil, err
}
// Determine TLS setting based on URL scheme
mgmTLSEnabled := mgmURL.Scheme == "https"
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
return nil, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
return &Auth{
client: mgmClient,
config: config,
privateKey: myPrivateKey,
mgmURL: mgmURL,
mgmTLSEnabled: mgmTLSEnabled,
}, nil
}
// Close closes the management client connection
func (a *Auth) Close() error {
a.mutex.Lock()
defer a.mutex.Unlock()
if a.client == nil {
return nil
}
return a.client.Close()
}
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
var supportsSSO bool
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
// Try PKCE flow first
_, err := a.getPKCEFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if PKCE is not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// PKCE not supported, try Device flow
_, err = a.getDeviceFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if Device flow is also not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// Neither PKCE nor Device flow is supported
supportsSSO = false
return nil
}
// Device flow check returned an error other than NotFound/Unimplemented
return err
}
// PKCE flow check returned an error other than NotFound/Unimplemented
return err
})
return supportsSSO, err
}
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
// This avoids creating a new connection to the management server
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
var flow OAuthFlow
var err error
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
if forceDeviceAuth {
flow, err = a.getDeviceFlow(client)
return err
}
// Try PKCE flow first
flow, err = a.getPKCEFlow(client)
if err != nil {
// If PKCE not supported, try Device flow
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
flow, err = a.getDeviceFlow(client)
return err
}
return err
}
return nil
})
return flow, err
}
// IsLoginRequired checks if login is required by attempting to authenticate with the server
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return false, err
}
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
}
needsLogin = false
return err
})
return needsLogin, err
}
// Login attempts to log in or register the client with the management server
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return err, false
}
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
} else if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
isAuthError = false
return nil
})
return err, isAuthError
}
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
RedirectURLs: protoConfig.GetRedirectURLs(),
UseIDToken: protoConfig.GetUseIDToken(),
ClientCertPair: a.config.ClientCertKeyPair,
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
}
if err := validatePKCEConfig(config); err != nil {
return nil, err
}
flow, err := NewPKCEAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
Scope: protoConfig.GetScope(),
UseIDToken: protoConfig.GetUseIDToken(),
}
// Keep compatibility with older management versions
if config.Scope == "" {
config.Scope = "openid"
}
if err := validateDeviceAuthConfig(config); err != nil {
return nil, err
}
flow, err := NewDeviceAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
// setSystemInfoFlags sets all configuration flags on the provided system info
func (a *Auth) setSystemInfoFlags(info *system.Info) {
info.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
}
// reconnect closes the current connection and creates a new one
// It checks if the brokenClient is still the current client before reconnecting
// to avoid multiple threads reconnecting unnecessarily
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
a.mutex.Lock()
defer a.mutex.Unlock()
// Double-check: if client has already been replaced by another thread, skip reconnection
if a.client != brokenClient {
log.Debugf("client already reconnected by another thread, skipping")
return nil
}
// Create new connection FIRST, before closing the old one
// This ensures a.client is never nil, preventing panics in other threads
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
if err != nil {
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
// Keep the old client if reconnection fails
return err
}
// Close old connection AFTER new one is successfully created
oldClient := a.client
a.client = mgmClient
if oldClient != nil {
if err := oldClient.Close(); err != nil {
log.Debugf("error closing old connection: %v", err)
}
}
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
return nil
}
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
func isConnectionError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
// These error codes indicate connection issues
return s.Code() == codes.Unavailable ||
s.Code() == codes.DeadlineExceeded ||
s.Code() == codes.Canceled ||
s.Code() == codes.Internal
}
// withRetry wraps an operation with exponential backoff retry logic
// It automatically reconnects on connection errors
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
backoffSettings := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
backoffSettings.Reset()
return backoff.RetryNotify(
func() error {
// Capture the client BEFORE the operation to ensure we track the correct client
a.mutex.RLock()
currentClient := a.client
a.mutex.RUnlock()
if currentClient == nil {
return status.Errorf(codes.Unavailable, "client is not initialized")
}
// Execute operation with the captured client
err := operation(currentClient)
if err == nil {
return nil
}
// If it's a connection error, attempt reconnection using the client that was actually used
if isConnectionError(err) {
log.Warnf("connection error detected, attempting reconnection: %v", err)
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
log.Errorf("reconnection failed: %v", reconnectErr)
return reconnectErr
}
// Return the original error to trigger retry with the new connection
return err
}
// For authentication errors, don't retry
if isAuthenticationError(err) {
return backoff.Permanent(err)
}
return err
},
backoff.WithContext(backoffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("operation failed, retrying in %v: %v", duration, err)
},
)
}
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
func isAuthenticationError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
}
// isPermissionDenied checks if the error is a PermissionDenied error.
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
func isPermissionDenied(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.PermissionDenied
}
func isLoginNeeded(err error) bool {
return isAuthenticationError(err)
}
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -25,56 +26,12 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validateDeviceAuthConfig validates device authorization provider configuration
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMsgFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
}
return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -100,7 +57,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -132,11 +89,6 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -247,22 +199,14 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,6 +12,8 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -113,19 +115,18 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -279,19 +280,18 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,6 +10,7 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -86,33 +87,19 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
return pkceFlowInfo, nil
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -127,9 +114,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
return deviceFlowInfo, nil
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -20,6 +20,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -34,67 +35,17 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validatePKCEConfig validates PKCE provider configuration
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
}
return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig PKCEAuthProviderConfig
providerConfig internal.PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
excludedRanges := getSystemExcludedPortRanges()
@@ -173,21 +124,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -198,7 +138,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -209,8 +149,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo
go p.startServer(server, tokenChan, errChan)
select {
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -49,7 +50,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := PKCEAuthProviderConfig{
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -9,6 +9,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
@@ -93,7 +95,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
availablePort := 65432
config := PKCEAuthProviderConfig{
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -20,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -245,7 +244,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
c.statusRecorder.UpdateLocalPeerState(localPeerState)

View File

@@ -0,0 +1,136 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return DeviceAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return DeviceAuthorizationFlow{}, err
}
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return DeviceAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return DeviceAuthorizationFlow{}, err
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -112,54 +112,6 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false,
shouldMatch: false,
},
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
}
for _, tt := range tests {

View File

@@ -9,10 +9,8 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -40,9 +38,6 @@ const (
type systemConfigurator struct {
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
}
func newHostManager() (*systemConfigurator, error) {
@@ -223,7 +218,6 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
}
var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false
inServerAddressesArray := false
@@ -250,12 +244,9 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
ip = ip.Unmap()
serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
@@ -267,19 +258,9 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port
dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil
}
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.origNameservers)
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {

View File

@@ -109,169 +109,3 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput()
return err
}
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -42,8 +42,6 @@ const (
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
@@ -200,11 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
// Update count even on error to ensure cleanup covers partially created rules
r.nrptEntryCount = count
if err != nil {
return fmt.Errorf("add dns match policy: %w", err)
}
r.nrptEntryCount = count
} else {
r.nrptEntryCount = 0
}
@@ -242,33 +239,23 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
// We need to batch domains into chunks and create one NRPT rule per batch.
ruleIndex := 0
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
end := i + nrptMaxDomainsPerRule
if end > len(domains) {
end = len(domains)
singleDomain := []string{domain}
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
}
batchDomains := domains[i:end]
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
}
// Increment immediately so the caller's cleanup path knows about this rule
ruleIndex++
if r.gpo {
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
}
}
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
log.Debugf("added NRPT entry for domain: %s", domain)
}
if r.gpo {
@@ -277,8 +264,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
}
}
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
return ruleIndex, nil
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return len(domains), nil
}
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {

View File

@@ -12,7 +12,6 @@ import (
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
@@ -38,60 +37,51 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
gpo: false,
}
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
domains125 := make([]DomainConfig, 125)
for i := 0; i < 125; i++ {
domains125[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config125 := HostDNSConfig{
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: domains125,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config125, nil)
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify 3 NRPT rules exist
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
for i := 0; i < 3; i++ {
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
domains75 := make([]DomainConfig, 75)
for i := 0; i < 75; i++ {
domains75[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config75 := HostDNSConfig{
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: domains75,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config75, nil)
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 NRPT rules exist
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify rule 2 is cleaned up
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
require.NoError(t, err)
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
@@ -107,106 +97,6 @@ func registryKeyExists(path string) (bool, error) {
}
func cleanupRegistryKeys(*testing.T) {
// Clean up more entries to account for batching tests with many domains
cfg := &registryConfigurator{nrptEntryCount: 20}
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
func TestNRPTDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
testCases := []struct {
name string
domainCount int
expectedRuleCount int
}{
{
name: "Less than 50 domains (single rule)",
domainCount: 30,
expectedRuleCount: 1,
},
{
name: "Exactly 50 domains (single rule)",
domainCount: 50,
expectedRuleCount: 1,
},
{
name: "51 domains (two rules)",
domainCount: 51,
expectedRuleCount: 2,
},
{
name: "100 domains (two rules)",
domainCount: 100,
expectedRuleCount: 2,
},
{
name: "125 domains (three rules: 50+50+25)",
domainCount: 125,
expectedRuleCount: 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clean up before each subtest
cleanupRegistryKeys(t)
// Generate domains
domains := make([]DomainConfig, tc.domainCount)
for i := 0; i < tc.domainCount; i++ {
domains[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config := HostDNSConfig{
ServerIP: testIP,
Domains: domains,
}
err := cfg.applyDNSConfig(config, nil)
require.NoError(t, err)
// Verify that exactly expectedRuleCount rules were created
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
// Verify all expected rules exist
for i := 0; i < tc.expectedRuleCount; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist", i)
}
// Verify no extra rules were created
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
require.NoError(t, err)
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
})
}
}

View File

@@ -84,18 +84,3 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
}
// EndBatch mock implementation of EndBatch from Server interface
func (m *MockServer) EndBatch() {
// Mock implementation - no-op
}
// CancelBatch mock implementation of CancelBatch from Server interface
func (m *MockServer) CancelBatch() {
// Mock implementation - no-op
}

View File

@@ -6,9 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
@@ -29,8 +27,6 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
@@ -45,9 +41,6 @@ type IosDnsManager interface {
type Server interface {
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int)
BeginBatch()
EndBatch()
CancelBatch()
Initialize() error
Stop()
DnsIP() netip.Addr
@@ -90,7 +83,6 @@ type DefaultServer struct {
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
batchMode bool
mgmtCacheResolver *mgmt.Resolver
@@ -238,9 +230,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++
}
if !s.batchMode {
s.applyHostConfig()
}
s.applyHostConfig()
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -269,41 +259,9 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
delete(s.extraDomains, zone)
}
}
if !s.batchMode {
s.applyHostConfig()
}
}
// BeginBatch starts batch mode for DNS handler registration/deregistration.
// In batch mode, applyHostConfig() is not called after each handler operation,
// allowing multiple handlers to be registered/deregistered efficiently.
// Must be followed by EndBatch() to apply the accumulated changes.
func (s *DefaultServer) BeginBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode enabled")
s.batchMode = true
}
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
func (s *DefaultServer) EndBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode disabled, applying accumulated changes")
s.batchMode = false
s.applyHostConfig()
}
// CancelBatch cancels batch mode without applying accumulated changes.
// This is useful when operations fail partway through and you want to
// discard partial state rather than applying it.
func (s *DefaultServer) CancelBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
s.batchMode = false
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
@@ -481,17 +439,6 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
}
if skipProbe {
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
return
}
}
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)
@@ -561,7 +508,6 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig.RouteAll = false
}
// Always apply host config for management updates, regardless of batch mode
s.applyHostConfig()
s.shutdownWg.Add(1)
@@ -669,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config)
}
// registerFallback registers original nameservers as low-priority fallback handlers.
// registerFallback registers original nameservers as low-priority fallback handlers
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok {
@@ -678,7 +624,6 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return
}
@@ -926,7 +871,6 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
// Always apply host config when nameserver goes down, regardless of batch mode
s.applyHostConfig()
go func() {
@@ -962,7 +906,6 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
}
// Always apply host config when nameserver reactivates, regardless of batch mode
s.applyHostConfig()
s.updateNSState(nsGroup, nil, true)

View File

@@ -18,12 +18,7 @@ func TestGetServerDns(t *testing.T) {
t.Errorf("invalid dns server instance: %s", err)
}
mockSrvB, ok := srvB.(*MockServer)
if !ok {
t.Errorf("returned server is not a MockServer")
}
if mockSrvB != srv {
if srvB != srv {
t.Errorf("mismatch dns instances")
}
}

View File

@@ -8,21 +8,15 @@ import (
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -190,75 +190,50 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return
return nil
}
question := query.Question[0]
qname := strings.ToLower(question.Name)
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
logger.Tracef("question: domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
f.cache.set(qname, question.Qtype, result.IPs)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
type udpResponseWriter struct {
dns.ResponseWriter
query *dns.Msg
}
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
opt := u.query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
maxSize = int(opt.UDPSize())
}
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
return u.ResponseWriter.WriteMsg(resp)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
@@ -268,7 +243,30 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
@@ -278,7 +276,18 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, w, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -325,7 +334,6 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
startTime time.Time,
) {
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
@@ -335,7 +343,9 @@ func (f *DNSForwarder) handleDNSError(
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
@@ -345,7 +355,9 @@ func (f *DNSForwarder) handleDNSError(
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
@@ -353,7 +365,9 @@ func (f *DNSForwarder) handleDNSError(
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
}
@@ -361,12 +375,15 @@ func (f *DNSForwarder) handleDNSError(
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
}
f.writeResponse(logger, w, resp, domain, startTime)
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// getMatchingEntries retrieves the resource IDs for a given domain.

View File

@@ -318,9 +318,8 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
@@ -330,9 +329,10 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
require.NotNil(t, resp, "Expected response")
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
@@ -466,16 +466,14 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else {
require.NotNil(t, resp, "Expected response")
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
@@ -530,10 +528,9 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
@@ -608,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -678,8 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -687,13 +683,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -719,8 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -732,13 +727,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2)
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -789,9 +784,8 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -903,15 +897,26 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, resp.Answer, "Response should have no answer records")
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
@@ -926,8 +931,15 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
query := &dns.Msg{}
// Don't set any question
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -28,7 +28,6 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
@@ -506,10 +505,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// Set up notrack rules immediately after proxy is listening to prevent
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -544,12 +539,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.shutdownWg.Add(1)
wgIfaceName := e.wgInterface.Name()
go func() {
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.triggerClientRestart()
} else if err != nil {
@@ -575,11 +569,9 @@ func (e *Engine) createFirewall() error {
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil {
return fmt.Errorf("create firewall manager: %w", err)
}
if e.firewall == nil {
return fmt.Errorf("create firewall manager: received nil manager")
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
}
if err := e.initFirewall(); err != nil {
@@ -625,23 +617,6 @@ func (e *Engine) initFirewall() error {
return nil
}
// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic.
// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy.
func (e *Engine) setupWGProxyNoTrack() {
if e.firewall == nil {
return
}
proxyPort := e.wgInterface.GetProxyPort()
if proxyPort == 0 {
return
}
if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil {
log.Warnf("failed to setup ebpf proxy notrack: %v", err)
}
}
func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
@@ -830,10 +805,6 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1023,7 +994,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
state.KernelInterface = device.WireGuardModuleIsLoaded()
state.FQDN = conf.GetFqdn()
e.statusRecorder.UpdateLocalPeerState(state)
@@ -1673,7 +1644,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
func (e *Engine) close() {
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
@@ -1924,7 +1894,7 @@ func (e *Engine) triggerClientRestart() {
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
return
}

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
@@ -95,10 +94,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
if netstack.IsEnabled() {
return nil
}
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
@@ -221,10 +216,6 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
if netstack.IsEnabled() {
return
}
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {

View File

@@ -107,7 +107,6 @@ type MockWGIface struct {
GetStatsFunc func() (map[string]configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
GetProxyPortFunc func() uint16
GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]monotime.Time
}
@@ -204,13 +203,6 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy {
return m.GetProxyFunc()
}
func (m *MockWGIface) GetProxyPort() uint16 {
if m.GetProxyPortFunc != nil {
return m.GetProxyPortFunc()
}
return 0
}
func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc()
}

View File

@@ -28,7 +28,6 @@ type wgIfaceBase interface {
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
GetProxyPort() uint16
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error

View File

@@ -11,7 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -75,13 +74,12 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is used on Windows, JS, and netstack platforms:
// BindListener is only used on Windows and JS platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
// BindListener bypasses this by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
return NewUDPListener(m.wgIface, peerCfg)
}

201
client/internal/login.go Normal file
View File

@@ -0,0 +1,201 @@
package internal
import (
"context"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
mgmURL := config.ManagementURL
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
if err != nil {
return false, err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", mgmURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return false, err
}
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login or register the client
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
var mgmTlsEnabled bool
if mgmURL.Scheme == "https" {
mgmTlsEnabled = true
}
log.Debugf("connecting to the Management service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
return nil, err
}
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
return true
}
return false
}
func isRegistrationNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.PermissionDenied {
return true
}
return false
}

View File

@@ -390,8 +390,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
@@ -404,13 +402,15 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
}
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
func (conn *Conn) onICEStateDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -430,10 +430,6 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
if conn.isReadyToUpgrade() {
conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay()
if sessionChanged {
conn.resetEndpoint()
}
conn.wgProxyRelay.Work()
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
@@ -505,9 +501,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
@@ -516,6 +509,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.enableWgWatcherIfNeeded()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
@@ -761,17 +756,6 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
return wgProxy, nil
}
func (conn *Conn) resetEndpoint() {
if !isController(conn.config) {
return
}
conn.Log.Infof("reset wg endpoint")
conn.wgWatcher.Reset()
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
}
}
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
}

View File

@@ -66,10 +66,6 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) RemoveEndpointAddress() error {
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
if e.cancelFunc == nil {
return

View File

@@ -2,7 +2,6 @@ package ice
import (
"context"
"fmt"
"sync"
"time"
@@ -33,6 +32,24 @@ type ThreadSafeAgent struct {
once sync.Once
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
@@ -76,41 +93,9 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
return nil, err
}
if agent == nil {
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
}
return &ThreadSafeAgent{Agent: agent}, nil
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
// Defensive check to prevent nil pointer dereference
// This can happen during sleep/wake transitions or memory corruption scenarios
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
agent := a.Agent
if agent == nil {
log.Warnf("ICE agent is nil during close, skipping")
return
}
done := make(chan error, 1)
go func() {
done <- agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {

View File

@@ -32,8 +32,6 @@ type WGWatcher struct {
enabled bool
muEnabled sync.RWMutex
resetCh chan struct{}
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -42,7 +40,6 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
wgIfaceStater: wgIfaceStater,
peerKey: peerKey,
stateDump: stateDump,
resetCh: make(chan struct{}, 1),
}
}
@@ -79,15 +76,6 @@ func (w *WGWatcher) IsEnabled() bool {
return w.enabled
}
// Reset signals the watcher that the WireGuard peer has been reset and a new
// handshake is expected. This restarts the handshake timeout from scratch.
func (w *WGWatcher) Reset() {
select {
case w.resetCh <- struct{}{}:
default:
}
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
@@ -117,12 +105,6 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
w.stateDump.WGcheckSuccess()
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-w.resetCh:
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
lastHandshake = time.Time{}
enabledTime = time.Now()
timer.Stop()
timer.Reset(wgHandshakeOvertime)
case <-ctx.Done():
w.log.Infof("WireGuard watcher stopped")
return

View File

@@ -52,9 +52,8 @@ type WorkerICE struct {
// increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID
remoteSessionChanged bool
muxAgent sync.Mutex
sessionID ICESessionID
muxAgent sync.Mutex
localUfrag string
localPwd string
@@ -107,12 +106,9 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return
}
w.log.Debugf("agent already exists, recreate the connection")
w.remoteSessionChanged = true
w.agentDialerCancel()
if w.agent != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
sessionID, err := NewICESessionID()
@@ -308,17 +304,13 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
cancel()
if err := agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
sessionChanged := w.remoteSessionChanged
w.remoteSessionChanged = false
if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
@@ -331,7 +323,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
w.agentConnecting = false
w.remoteSessionID = ""
}
return sessionChanged
w.muxAgent.Unlock()
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -432,11 +424,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
sessionChanged := w.closeAgent(agent, dialerCancel)
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected(sessionChanged)
w.conn.onICEStateDisconnected()
}
default:
return

View File

@@ -0,0 +1,138 @@
package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultAdminURL)
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err

View File

@@ -173,21 +173,12 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
}
func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
},
)
@@ -346,23 +337,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
}
var merr *multierror.Error
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
batchStarted := false
if m.dnsServer != nil {
m.dnsServer.BeginBatch()
batchStarted = true
defer func() {
if merr != nil {
// On error, cancel batch to discard partial DNS state
m.dnsServer.CancelBatch()
} else {
// On success, apply accumulated DNS changes
m.dnsServer.EndBatch()
}
}()
}
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
@@ -393,7 +367,6 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
m.activeRoutes[id] = handler
}
_ = batchStarted // Mark as used
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -4,17 +4,16 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
@@ -25,51 +24,42 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
if flags&unix.RTF_CLONING != 0 {
if flags&syscall.RTF_CLONING != 0 {
flagStrs = append(flagStrs, "C")
}
if flags&unix.RTF_WASCLONED != 0 {
if flags&syscall.RTF_WASCLONED != 0 {
flagStrs = append(flagStrs, "W")
}
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -4,18 +4,17 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
@@ -26,46 +25,37 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -9,8 +9,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
@@ -37,11 +35,6 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
return false, errors.New("not supported on mobile platforms")
}
if netstack.IsEnabled() {
log.Debugf("Interface monitor: skipped in netstack mode")
return false, nil
}
if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name")

View File

@@ -263,14 +263,7 @@ func (c *Client) IsLoginRequired() bool {
return true
}
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
return true // Assume login is required if we can't create auth client
}
defer authClient.Close()
needsLogin, err := authClient.IsLoginRequired(ctx)
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
if err != nil {
log.Errorf("IsLoginRequired: check failed: %v", err)
// If the check fails, assume login is required to be safe
@@ -321,19 +314,16 @@ func (c *Client) LoginForMobile() string {
// This could cause a potential race condition with loading the extension which need to be handled on swift side
go func() {
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
return
}
jwtToken := tokenInfo.GetTokenToUse()
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
return
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
log.Errorf("LoginForMobile: Login failed: %v", err)
return
}

View File

@@ -2,10 +2,7 @@
package NetBirdSDK
import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
import "github.com/netbirdio/netbird/client/internal/peer"
// EnvList is an exported struct to be bound by gomobile
type EnvList struct {
@@ -35,13 +32,3 @@ func (el *EnvList) AllItems() map[string]string {
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
func GetEnvKeyNBLazyConn() string {
return lazyconn.EnvEnableLazyConn
}
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
func GetEnvKeyNBInactivityThreshold() string {
return lazyconn.EnvInactivityThreshold
}

View File

@@ -7,8 +7,13 @@ import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -85,21 +90,34 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
return err
}
return err
})
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
@@ -123,17 +141,19 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err, _ = authClient.Login(ctxWithValues, setupKey, "")
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
@@ -144,16 +164,15 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
// LoginSync performs a synchronous login check without UI interaction
// Used for background VPN connection where user should already be authenticated
func (a *Auth) LoginSync() error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
var needsLogin bool
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(a.ctx)
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
@@ -161,12 +180,15 @@ func (a *Auth) LoginSync() error {
return fmt.Errorf("not authenticated")
}
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
if err != nil {
if isAuthError {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
@@ -203,6 +225,8 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
}
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
var needsLogin bool
// Create context with device name if provided
ctx := a.ctx
if deviceName != "" {
@@ -210,33 +234,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
}
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(ctx)
err := a.withBackOff(ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err, isAuthError := authClient.Login(ctx, "", jwtToken)
if err != nil {
if isAuthError {
err = a.withBackOff(ctx, func() error {
err := internal.Login(ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
@@ -261,10 +285,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
const authInfoRequestTimeout = 30 * time.Second
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
if err != nil {
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
return nil, err
}
// Use a bounded timeout for the auth info request to prevent indefinite hangs
@@ -289,6 +313,15 @@ func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
// GetConfigJSON returns the current config as a JSON string.
// This can be used by the caller to persist the config via alternative storage
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).

View File

@@ -253,17 +253,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
if err != nil {
log.Errorf("failed to create auth client: %v", err)
return internal.StatusLoginFailed, err
}
defer authClient.Close()
var status internal.StatusType
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
err := internal.Login(ctx, s.config, setupKey, jwtToken)
if err != nil {
if isAuthError {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
log.Warnf("failed login: %v", err)
status = internal.StatusNeedsLogin
} else {
@@ -588,7 +581,8 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel()
}
waitCTX, cancel := context.WithCancel(ctx)
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
s.mutex.Lock()

View File

@@ -207,6 +207,8 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
}
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
@@ -214,28 +216,10 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
}
defer func() { _ = serverSession.Close() }()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
<-session.Context().Done()
if err := serverSession.Shell(); err != nil {
log.Debugf("start shell: %v", err)
return
}
done := make(chan error, 1)
go func() {
done <- serverSession.Wait()
}()
select {
case <-session.Context().Done():
return
case err := <-done:
if err != nil {
log.Debugf("shell session: %v", err)
p.handleProxyExitCode(session, err)
}
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
}
}

View File

@@ -12,8 +12,8 @@ import (
log "github.com/sirupsen/logrus"
)
// handleExecution executes an SSH command or shell with privilege validation
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
// handleCommand executes an SSH command with privilege validation
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
hasPty := winCh != nil
commandType := "command"
@@ -23,7 +23,7 @@ func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privile
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err)
@@ -51,12 +51,13 @@ func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privile
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType)
}
}
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
@@ -65,28 +66,28 @@ func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheck
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
cmd, err := s.createSuCommand(session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback {
logger.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, func() {}, nil
}

View File

@@ -15,17 +15,17 @@ import (
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
// createSuCommand is not supported on JS/WASM
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported
}
// createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported
}
// prepareCommandEnv is not supported on JS/WASM
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
return nil
}

View File

@@ -10,7 +10,6 @@ import (
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"sync"
@@ -100,52 +99,40 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
return isUtilLinux
}
// createSuCommand creates a command using su - for privilege switching.
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
if err := validateUsername(localUser.Username); err != nil {
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
args := []string{"-"}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
if hasPty && s.suSupportsPty {
args = append(args, "--pty")
}
args = append(args, localUser.Username)
args = append(args, localUser.Username, "-c", command)
command := session.RawCommand()
if command != "" {
args = append(args, "-c", command)
}
logger.Debugf("creating su command: %s %v", suPath, args)
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string.
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell}
return []string{shell, "-l"}
}
return []string{shell, "-c", cmdString}
}
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
cmd := exec.CommandContext(ctx, shell, args[1:]...)
cmd.Args[0] = "-" + filepath.Base(shell)
return cmd
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
@@ -167,7 +154,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil {
logger.Errorf("Pty command creation failed: %v", err)
@@ -257,6 +244,11 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
}()
go func() {
defer func() {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Debugf("session close error: %v", err)
}
}()
if _, err := io.Copy(session, ptmx); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty output copy error: %v", err)
@@ -276,7 +268,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
case <-ctx.Done():
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
case err := <-done:
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
s.handlePtyCommandCompletion(logger, session, err)
}
}
@@ -304,20 +296,17 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
}
}
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
if err != nil {
logger.Debugf("Pty command execution failed: %v", err)
s.handleSessionExit(session, err, logger)
} else {
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
return
}
// Close PTY to unblock io.Copy goroutines
if err := ptyMgr.Close(); err != nil {
logger.Debugf("Pty close after completion: %v", err)
// Normal completion
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}

View File

@@ -20,32 +20,32 @@ import (
// getUserEnvironment retrieves the Windows environment for the target user.
// Follows OpenSSH's resilient approach with graceful degradation on failures.
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
userToken, err := s.getUserToken(logger, username, domain)
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
userToken, err := s.getUserToken(username, domain)
if err != nil {
return nil, fmt.Errorf("get user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
logger.Debugf("close user token: %v", err)
log.Debugf("close user token: %v", err)
}
}()
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
return s.getUserEnvironmentWithToken(userToken, username, domain)
}
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil {
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
}
envMap := make(map[string]string)
if err := s.loadSystemEnvironment(envMap); err != nil {
logger.Debugf("failed to load system environment from registry: %v", err)
log.Debugf("failed to load system environment from registry: %v", err)
}
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken window
}
// getUserToken creates a user token for the specified user.
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper()
token, err := privilegeDropper.createToken(username, domain)
if err != nil {
return 0, fmt.Errorf("generate S4U user token: %w", err)
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
}
// prepareCommandEnv prepares environment variables for command execution on Windows
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
username, domain := s.parseUsername(localUser.Username)
userEnv, err := s.getUserEnvironment(logger, username, domain)
userEnv, err := s.getUserEnvironment(username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
@@ -267,16 +267,22 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
return env
}
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
if privilegeResult.User == nil {
logger.Errorf("no user in privilege result")
return false
}
cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid)
logger.Infof("starting interactive shell: %s", shell)
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
if len(cmd) == 0 {
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
return true
}
@@ -288,6 +294,11 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
return []string{shell, "-Command", cmdString}
}
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell")
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
}
type PtyExecutionRequest struct {
Shell string
Command string
@@ -297,25 +308,25 @@ type PtyExecutionRequest struct {
Domain string
}
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
privilegeDropper := NewPrivilegeDropper()
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
if err != nil {
return fmt.Errorf("create user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
logger.Debugf("close user token: %v", err)
log.Debugf("close user token: %v", err)
}
}()
server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
if err != nil {
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ()
}
@@ -337,8 +348,8 @@ func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req
Environment: userEnv,
}
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
}
func getUserHomeFromEnv(env []string) string {
@@ -360,8 +371,10 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
return
}
logger := log.WithField("pid", cmd.Process.Pid)
if err := cmd.Process.Kill(); err != nil {
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
logger.Debugf("kill process failed: %v", err)
}
}
@@ -376,7 +389,21 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
localUser := privilegeResult.User
if localUser == nil {
logger.Errorf("no user in privilege result")
@@ -388,14 +415,14 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
req := PtyExecutionRequest{
Shell: shell,
Command: session.RawCommand(),
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)

View File

@@ -4,15 +4,12 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
"time"
@@ -26,67 +23,25 @@ import (
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
// binary with "ssh exec" args. We handle that here to properly execute commands and
// propagate exit codes.
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
runTestExecutor()
return
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
testutil.CleanupTestUsers()
os.Exit(code)
}
// runTestExecutor emulates the netbird executor for tests.
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
func runTestExecutor() {
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
os.Exit(1)
}
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
shell := "/bin/sh"
var command string
for i := 3; i < len(os.Args); i++ {
switch os.Args[i] {
case "--shell":
if i+1 < len(os.Args) {
shell = os.Args[i+1]
i++
}
case "--cmd":
if i+1 < len(os.Args) {
command = os.Args[i+1]
i++
}
}
}
var cmd *exec.Cmd
if command == "" {
cmd = exec.Command(shell)
} else {
cmd = exec.Command(shell, "-c", command)
}
cmd.Args[0] = "-" + filepath.Base(shell)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
os.Exit(0)
}
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
func TestSSHServerCompatibility(t *testing.T) {
if testing.Short() {
@@ -450,171 +405,6 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
return createTempKeyFileFromBytes(t, privateKey)
}
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
// This ensures our implementation matches OpenSSH behavior for:
// - ssh host command (no PTY - default when no TTY)
// - ssh -T host command (explicit no PTY)
// - ssh -t host command (force PTY)
// - ssh -T host (no PTY shell - our implementation)
func TestSSHPtyModes(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH PTY mode tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
}
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
username := testutil.GetTestUsername(t)
baseArgs := []string{
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "BatchMode=yes",
}
t.Run("command_default_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
assert.Contains(t, string(output), "no_pty_default")
})
t.Run("command_explicit_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
assert.Contains(t, string(output), "explicit_no_pty")
})
t.Run("command_force_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
assert.Contains(t, string(output), "force_pty")
})
t.Run("shell_explicit_no_pty", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
cmd := exec.CommandContext(ctx, "ssh", args...)
stdin, err := cmd.StdinPipe()
require.NoError(t, err)
stdout, err := cmd.StdoutPipe()
require.NoError(t, err)
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
go func() {
defer stdin.Close()
time.Sleep(100 * time.Millisecond)
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
assert.NoError(t, err, "write echo command")
time.Sleep(100 * time.Millisecond)
_, err = stdin.Write([]byte("exit 0\n"))
assert.NoError(t, err, "write exit command")
}()
output, _ := io.ReadAll(stdout)
err = cmd.Wait()
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
assert.Contains(t, string(output), "shell_no_pty_test")
})
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "Command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
})
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "PTY command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
})
t.Run("stderr_works_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "stderr test failed")
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
})
t.Run("stderr_merged_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "PTY stderr test failed: %s", output)
assert.Contains(t, string(output), "stdout_msg")
assert.Contains(t, string(output), "stderr_msg")
})
}
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
func TestSSHServerFeatureCompatibility(t *testing.T) {
if testing.Short() {

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
@@ -36,35 +35,11 @@ type ExecutorConfig struct {
}
// PrivilegeDropper handles secure privilege dropping in child processes
type PrivilegeDropper struct {
logger *log.Entry
}
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
type PrivilegeDropperOption func(*PrivilegeDropper)
type PrivilegeDropper struct{}
// NewPrivilegeDropper creates a new privilege dropper
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
@@ -108,7 +83,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
break
}
}
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
return exec.CommandContext(ctx, netbirdPath, args...), nil
}
@@ -231,22 +206,17 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
var execCmd *exec.Cmd
if config.Command == "" {
execCmd = exec.CommandContext(ctx, config.Shell)
} else {
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
os.Exit(ExitCodeSuccess)
}
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
execCmd.Stdin = os.Stdin
execCmd.Stdout = os.Stdout
execCmd.Stderr = os.Stderr
if config.Command == "" {
log.Tracef("executing login shell: %s", execCmd.Path)
} else {
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
}
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
if err := execCmd.Run(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {

View File

@@ -28,45 +28,22 @@ const (
)
type WindowsExecutorConfig struct {
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Pty bool
PtyWidth int
PtyHeight int
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Interactive bool
Pty bool
PtyWidth int
PtyHeight int
}
type PrivilegeDropper struct {
logger *log.Entry
}
type PrivilegeDropper struct{}
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
type PrivilegeDropperOption func(*PrivilegeDropper)
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
var (
@@ -79,6 +56,7 @@ const (
// Common error messages
commandFlag = "-Command"
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
convertUsernameError = "convert username to UTF16: %w"
convertDomainError = "convert domain to UTF16: %w"
)
@@ -102,7 +80,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
shellArgs = []string{shell}
}
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
cmd, token, err := pd.CreateWindowsProcessAsUser(
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
@@ -202,10 +180,10 @@ func newLsaString(s string) lsaString {
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)
pd := NewPrivilegeDropper(WithLogger(logger))
pd := NewPrivilegeDropper()
isDomainUser := !pd.isLocalUser(domain)
lsaHandle, err := initializeLsaConnection()
@@ -219,12 +197,12 @@ func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.H
return 0, err
}
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
if err != nil {
return 0, err
}
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
}
// buildUserCpn constructs the user principal name
@@ -332,21 +310,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
}
// prepareS4ULogonStructure creates the appropriate S4U logon structure
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
if isDomainUser {
return prepareDomainS4ULogon(logger, username, domain)
return prepareDomainS4ULogon(username, domain)
}
return prepareLocalS4ULogon(logger, username)
return prepareLocalS4ULogon(username)
}
// prepareDomainS4ULogon creates S4U logon structure for domain users
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
upn, err := lookupPrincipalName(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
}
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
upnUtf16, err := windows.UTF16FromString(upn)
if err != nil {
@@ -379,8 +357,8 @@ func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.P
}
// prepareLocalS4ULogon creates S4U logon structure for local users
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
usernameUtf16, err := windows.UTF16FromString(username)
if err != nil {
@@ -428,11 +406,11 @@ func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, u
}
// performS4ULogon executes the S4U logon operation
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
var tokenSource tokenSource
copy(tokenSource.SourceName[:], "netbird")
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
logger.Debugf("AllocateLocallyUniqueId failed")
log.Debugf("AllocateLocallyUniqueId failed")
}
originName := newLsaString("netbird")
@@ -463,7 +441,7 @@ func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId
if profile != 0 {
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
}
}
@@ -471,7 +449,7 @@ func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
}
logger.Debugf("created S4U %s token for user %s",
log.Debugf("created S4U %s token for user %s",
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
return token, nil
}
@@ -519,8 +497,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
// authenticateLocalUser handles authentication for local users
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(pd.log(), username, ".")
log.Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(username, ".")
if err != nil {
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
}
@@ -529,12 +507,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
// authenticateDomainUser handles authentication for domain users
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(pd.log(), username, domain)
log.Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(username, domain)
if err != nil {
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
}
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
return token, nil
}
@@ -548,7 +526,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
defer func() {
if err := windows.CloseHandle(token); err != nil {
pd.log().Debugf("close impersonation token: %v", err)
log.Debugf("close impersonation token: %v", err)
}
}()
@@ -586,7 +564,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
return cmd, primaryToken, nil
}
// createSuCommand creates a command using su - for privilege switching (Windows stub).
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
return nil, fmt.Errorf("su command not available on Windows")
}

View File

@@ -54,7 +54,7 @@ func TestJWTEnforcement(t *testing.T) {
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -88,7 +88,7 @@ func TestJWTEnforcement(t *testing.T) {
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer func() { require.NoError(t, serverNoJWT.Stop()) }()
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
@@ -213,7 +213,7 @@ func TestJWTDetection(t *testing.T) {
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -341,7 +341,7 @@ func TestJWTFailClose(t *testing.T) {
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -596,7 +596,7 @@ func TestJWTAuthentication(t *testing.T) {
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
@@ -715,7 +715,7 @@ func TestJWTMultipleAudiences(t *testing.T) {
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer func() { require.NoError(t, server.Stop()) }()
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)

View File

@@ -271,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding
}
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
func (s *Server) isPortForwardingEnabled() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg

View File

@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
sessions = append(sessions, info)
}
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
for key, connState := range s.connections {
remoteAddr := string(key)
if reportedAddrs[remoteAddr] {

View File

@@ -483,11 +483,12 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
}
}
func TestServer_NonPtyShellSession(t *testing.T) {
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
func TestServer_PortForwardingOnlySession(t *testing.T) {
// Test that sessions without PTY and command are allowed when port forwarding is enabled
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
@@ -495,26 +496,36 @@ func TestServer_NonPtyShellSession(t *testing.T) {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
expectAllowed bool
description string
}{
{
name: "shell_with_local_forwarding_enabled",
name: "session_allowed_with_local_forwarding",
allowLocalForwarding: true,
allowRemoteForwarding: false,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
},
{
name: "shell_with_remote_forwarding_enabled",
name: "session_allowed_with_remote_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
},
{
name: "shell_with_both_forwarding_enabled",
name: "session_allowed_with_both",
allowLocalForwarding: true,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
},
{
name: "shell_with_forwarding_disabled",
name: "session_denied_without_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: false,
expectAllowed: false,
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
},
}
@@ -534,6 +545,7 @@ func TestServer_NonPtyShellSession(t *testing.T) {
_ = server.Stop()
}()
// Connect to the server without requesting PTY or command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -545,10 +557,20 @@ func TestServer_NonPtyShellSession(t *testing.T) {
_ = client.Close()
}()
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
// Should always succeed regardless of port forwarding settings
_, err = client.ExecuteCommand(ctx, "")
assert.NoError(t, err, "Non-PTY shell session should be allowed")
// Execute a command without PTY - this simulates ssh -T with no command
// The server should either allow it (port forwarding enabled) or reject it
output, err := client.ExecuteCommand(ctx, "")
if tt.expectAllowed {
// When allowed, the session stays open until cancelled
// ExecuteCommand with empty command should return without error
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
assert.NotContains(t, output, "port forwarding is disabled",
"Output should not contain port forwarding disabled message")
} else if err != nil {
// When denied, we expect an error message about port forwarding being disabled
assert.Contains(t, err.Error(), "port forwarding is disabled",
"Should get port forwarding disabled message")
}
})
}
}

View File

@@ -405,14 +405,12 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2])
} else {
// Test Unix shell behavior
args := server.getShellCommandArgs("/bin/sh", "echo test")
assert.Equal(t, "/bin/sh", args[0])
assert.Equal(t, "-c", args[1])
assert.Equal(t, "echo test", args[2])
args = server.getShellCommandArgs("/bin/sh", "")
assert.Equal(t, "/bin/sh", args[0])
assert.Len(t, args, 1)
assert.Equal(t, "-l", args[1])
assert.Equal(t, "-c", args[2])
assert.Equal(t, "echo test", args[3])
}
}

View File

@@ -62,12 +62,54 @@ func (s *Server) sessionHandler(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login)
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
} else {
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
switch {
case isPty && hasCommand:
// ssh -t <host> <cmd> - Pty command execution
s.handleCommand(logger, session, privilegeResult, winCh)
case isPty:
// ssh <host> - Pty interactive session (login)
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
case hasCommand:
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
// ssh -T (or ssh -N) - no PTY, no command
s.handleNonInteractiveSession(logger, session)
}
}
// handleNonInteractiveSession handles sessions that have no PTY and no command.
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
s.updateSessionType(session, cmdNonInteractive)
if !s.isPortForwardingEnabled() {
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-interactive session: port forwarding disabled")
return
}
<-session.Context().Done()
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, state := range s.sessions {
if state.session == session {
state.sessionType = sessionType
return
}
}
}

Some files were not shown because too many files have changed in this diff Show More