Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
20bb4095f2 Bump github.com/docker/docker
Bumps [github.com/docker/docker](https://github.com/docker/docker) from 26.1.5+incompatible to 28.0.0+incompatible.
- [Release notes](https://github.com/docker/docker/releases)
- [Commits](https://github.com/docker/docker/compare/v26.1.5...v28.0.0)

---
updated-dependencies:
- dependency-name: github.com/docker/docker
  dependency-version: 28.0.0+incompatible
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-08-27 16:32:02 +00:00
372 changed files with 4046 additions and 15657 deletions

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386" - arch: "386"
raceFlag: "" raceFlag: ""
- arch: "amd64" - arch: "amd64"
raceFlag: "-race" raceFlag: ""
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -382,32 +382,6 @@ jobs:
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@@ -454,10 +428,9 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \ CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http) -timeout 20m ./management/... ./shared/management/...
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -548,7 +521,7 @@ jobs:
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/server/http/... -timeout 20m ./management/... ./shared/management/...
api_integration_test: api_integration_test:
name: "Management / Integration" name: "Management / Integration"
@@ -598,4 +571,4 @@ jobs:
CI=true \ CI=true \
go test -tags=integration \ go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/server/http/... -timeout 20m ./management/... ./shared/management/...

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum skip: go.mod,go.sum
golangci: golangci:
strategy: strategy:

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.23" SIGN_PIPE_VER: "v0.0.22"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"

View File

@@ -1,67 +0,0 @@
name: Wasm
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:
name: "JS / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
CGO_ENABLED: 0
- name: Check Wasm build size
run: |
echo "Wasm build size:"
ls -lh netbird.wasm
SIZE=$(stat -c%s netbird.wasm)
SIZE_MB=$((SIZE / 1024 / 1024))
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1
fi

0
.gitmodules vendored
View File

View File

@@ -2,18 +2,6 @@ version: 2
project_name: netbird project_name: netbird
builds: builds:
- id: netbird-wasm
dir: client/wasm/cmd
binary: netbird
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
goos:
- js
goarch:
- wasm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird - id: netbird
dir: client dir: client
binary: netbird binary: netbird
@@ -127,11 +115,6 @@ archives:
- builds: - builds:
- netbird - netbird
- netbird-static - netbird-static
- id: netbird-wasm
builds:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>

View File

@@ -1,4 +1,3 @@
<div align="center"> <div align="center">
<br/> <br/>
<br/> <br/>
@@ -53,7 +52,7 @@
### Open Source Network Security in a Single Platform ### Open Source Network Security in a Single Platform
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 <img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
### NetBird on Lawrence Systems (Video) ### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.2 FROM alpine:3.22.0
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \
@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="5" NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -4,7 +4,6 @@ package android
import ( import (
"context" "context"
"os"
"slices" "slices"
"sync" "sync"
@@ -19,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/util/net"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -84,8 +83,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
} }
// Run start the internal client. It is a blocker function // Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -120,8 +118,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps. // In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
@@ -252,14 +249,3 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() { func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener() c.recorder.RemoveConnectionListener()
} }
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -1,32 +0,0 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -33,7 +33,6 @@ type ErrListener interface {
// the backend want to show an url for the user // the backend want to show an url for the user
type URLOpener interface { type URLOpener interface {
Open(string) Open(string)
OnLoginSuccess()
} }
// Auth can register or login new client // Auth can register or login new client
@@ -182,11 +181,6 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error { err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken) err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil return nil
} }

View File

@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
if err != nil { if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
} }
@@ -303,18 +303,12 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
func getStatusOutput(cmd *cobra.Command, anon bool) string { func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context(), true) statusResp, err := getStatus(cmd.Context())
if err != nil { if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -1,8 +0,0 @@
package cmd
import "context"
// SetupDebugHandler is a no-op for WASM
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
// Debug handler not needed for WASM
}

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/exec"
"os/user" "os/user"
"runtime" "runtime"
"strings" "strings"
@@ -228,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
} }
// update host's static platform and system information // update host's static platform and system information
system.UpdateStaticInfoAsync() system.UpdateStaticInfo()
configFilePath, err := activeProf.FilePath() configFilePath, err := activeProf.FilePath()
if err != nil { if err != nil {
@@ -357,21 +356,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {
if err := openBrowser(verificationURIComplete); err != nil { if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
} }
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server. // DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*10) ctx, cancel := context.WithTimeout(ctx, time.Second*3)
defer cancel() defer cancel()
return grpc.DialContext( return grpc.DialContext(

View File

@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint log.Info("starting NetBird service") //nolint
// Collect static system and platform information // Collect static system and platform information
system.UpdateStaticInfoAsync() system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer() p.serv = grpc.NewServer()

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, false) resp, err := getStatus(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }

View File

@@ -9,28 +9,29 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server" mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server" sig "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
) )
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
@@ -89,20 +90,15 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT(). settingsMockManager.EXPECT().
@@ -116,7 +112,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -230,9 +230,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{ status, err := client.Status(ctx, &proto.StatusRequest{})
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil { if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err) return fmt.Errorf("unable to get daemon status: %v", err)
} }

View File

@@ -23,29 +23,23 @@ import (
var ErrClientAlreadyStarted = errors.New("client already started") var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started") var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
// Client manages a netbird embedded client instance. // Client manages a netbird embedded client instance
type Client struct { type Client struct {
deviceName string deviceName string
config *profilemanager.Config config *profilemanager.Config
mu sync.Mutex mu sync.Mutex
cancel context.CancelFunc cancel context.CancelFunc
setupKey string setupKey string
jwtToken string
connect *internal.ConnectClient connect *internal.ConnectClient
} }
// Options configures a new Client. // Options configures a new Client
type Options struct { type Options struct {
// DeviceName is this peer's name in the network // DeviceName is this peer's name in the network
DeviceName string DeviceName string
// SetupKey is used for authentication // SetupKey is used for authentication
SetupKey string SetupKey string
// JWTToken is used for JWT-based authentication
JWTToken string
// PrivateKey is used for direct private key authentication
PrivateKey string
// ManagementURL overrides the default management server URL // ManagementURL overrides the default management server URL
ManagementURL string ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface // PreSharedKey is the pre-shared key for the WireGuard interface
@@ -64,35 +58,8 @@ type Options struct {
DisableClientRoutes bool DisableClientRoutes bool
} }
// validateCredentials checks that exactly one credential type is provided // New creates a new netbird embedded client
func (opts *Options) validateCredentials() error {
credentialsProvided := 0
if opts.SetupKey != "" {
credentialsProvided++
}
if opts.JWTToken != "" {
credentialsProvided++
}
if opts.PrivateKey != "" {
credentialsProvided++
}
if credentialsProvided == 0 {
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
}
if credentialsProvided > 1 {
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
}
return nil
}
// New creates a new netbird embedded client.
func New(opts Options) (*Client, error) { func New(opts Options) (*Client, error) {
if err := opts.validateCredentials(); err != nil {
return nil, err
}
if opts.LogOutput != nil { if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput) logrus.SetOutput(opts.LogOutput)
} }
@@ -140,14 +107,9 @@ func New(opts Options) (*Client, error) {
return nil, fmt.Errorf("create config: %w", err) return nil, fmt.Errorf("create config: %w", err)
} }
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey
}
return &Client{ return &Client{
deviceName: opts.DeviceName, deviceName: opts.DeviceName,
setupKey: opts.SetupKey, setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config, config: config,
}, nil }, nil
} }
@@ -164,7 +126,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background()) ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err) return fmt.Errorf("login: %w", err)
} }
@@ -173,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan struct{}) run := make(chan struct{}, 1)
clientErr := make(chan error, 1) clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {
@@ -225,16 +187,6 @@ func (c *Client) Stop(ctx context.Context) error {
} }
} }
// GetConfig returns a copy of the internal client config.
func (c *Client) GetConfig() (profilemanager.Config, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.config == nil {
return profilemanager.Config{}, ErrConfigNotInitialized
}
return *c.config, nil
}
// Dial dials a network address in the netbird network. // Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -259,7 +211,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address) return nsnet.DialContext(ctx, network, address)
} }
// ListenTCP listens on the given address in the netbird network. // ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) { func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet() nsnet, addr, err := c.getNet()
@@ -280,7 +232,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
return nsnet.ListenTCP(tcpAddr) return nsnet.ListenTCP(tcpAddr)
} }
// ListenUDP listens on the given address in the netbird network. // ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) { func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet() nsnet, addr, err := c.getNet()

View File

@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -400,6 +400,7 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return "" return ""
} }
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := "" actionSuffix := ""
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
actionSuffix = "-drop" actionSuffix = "-drop"

View File

@@ -260,22 +260,6 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
@@ -880,54 +880,6 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {
if port == nil { if port == nil {
return nil return nil

View File

@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {

View File

@@ -151,20 +151,14 @@ type Manager interface {
DisableRouting() error DisableRouting() error
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network. // AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error) AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes the outbound DNAT rule. // DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes // UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (

View File

@@ -376,22 +376,6 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -1350,103 +1350,6 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets // applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork( func (r *router) applyNetwork(
network firewall.Network, network firewall.Network,

View File

@@ -22,8 +22,6 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64 PacketsRx atomic.Uint64
BytesTx atomic.Uint64 BytesTx atomic.Uint64
BytesRx atomic.Uint64 BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -171,30 +171,28 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists { if exists {
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true return key, true
} }
return key, 0, false return key, false
} }
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed // TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
return 0 }
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 { if exists || flags&TCPSyn == 0 {
return return
} }
@@ -212,13 +210,8 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false) conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew)) conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key) t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
t.mutex.Lock() t.mutex.Lock()
@@ -456,21 +449,6 @@ func (t *TCPTracker) cleanup() {
} }
} }
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.tickerCancel() t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80) serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound) // 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
key := ConnKey{ key := ConnKey{
SrcIP: clientIP, SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake // 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer // 4. Test data transfer
// Client sends data // Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
// Server sends ACK for data // Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data // Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
// Verify state and counters // Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState()) require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,23 +58,20 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
if exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
return 0 }
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
} }
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -89,15 +86,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true return key, true
} }
return key, 0, false return key, false
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists { if exists {
return return
} }
@@ -112,7 +109,6 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
} }
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
@@ -120,11 +116,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key) t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }

View File

@@ -109,10 +109,6 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex dnatMutex sync.RWMutex
dnatBiMap *biDNATMap dnatBiMap *biDNATMap
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
} }
// decoder for packages // decoder for packages
@@ -126,8 +122,6 @@ type decoder struct {
icmp6 layers.ICMPv6 icmp6 layers.ICMPv6
decoded []gopacket.LayerType decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
@@ -202,7 +196,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr), dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
} }
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
@@ -637,7 +630,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true return true
} }
m.trackOutbound(d, srcIP, dstIP, packetData, size) m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d) m.translateOutboundDNAT(packetData, d)
return false return false
@@ -681,26 +674,14 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
} }
@@ -710,15 +691,13 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
} }
d.dnatOrigPort = 0
} }
// udpHooksDrop checks if any UDP hooks should drop the packet // udpHooksDrop checks if any UDP hooks should drop the packet
@@ -780,20 +759,10 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false return false
} }
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated { if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses // Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
return true return true
} }
srcIP, dstIP = m.extractIPs(d) srcIP, dstIP = m.extractIPs(d)

View File

@@ -50,8 +50,6 @@ type logMessage struct {
arg4 any arg4 any
arg5 any arg5 any
arg6 any arg6 any
arg7 any
arg8 any
} }
// Logger is a high-performance, non-blocking logger // Logger is a high-performance, non-blocking logger
@@ -96,6 +94,7 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
func (l *Logger) Error(format string) { func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
select { select {
@@ -186,15 +185,6 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
} }
} }
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
@@ -249,16 +239,6 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
} }
} }
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0] *buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
@@ -280,12 +260,6 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
argCount++ argCount++
if msg.arg6 != nil { if msg.arg6 != nil {
argCount++ argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
} }
} }
} }
@@ -309,10 +283,6 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6: case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
case 7:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
case 8:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
} }
*buf = append(*buf, formatted...) *buf = append(*buf, formatted...)

View File

@@ -5,9 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -15,21 +13,6 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 { func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 { if len(header) < 20 {
return 0 return 0
@@ -69,7 +52,6 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 { func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32 var sum1, sum2, sum3, sum4 uint32
i := 0 i := 0
@@ -107,21 +89,11 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct { type biDNATMap struct {
forward map[netip.Addr]netip.Addr forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr
} }
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap { func newBiDNATMap() *biDNATMap {
return &biDNATMap{ return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr), forward: make(map[netip.Addr]netip.Addr),
@@ -129,13 +101,11 @@ func newBiDNATMap() *biDNATMap {
} }
} }
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) { func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated b.forward[original] = translated
b.reverse[translated] = original b.reverse[translated] = original
} }
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) { func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists { if translated, exists := b.forward[original]; exists {
delete(b.forward, original) delete(b.forward, original)
@@ -143,25 +113,19 @@ func (b *biDNATMap) delete(original netip.Addr) {
} }
} }
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original] translated, exists := b.forward[original]
return translated, exists return translated, exists
} }
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated] original, exists := b.reverse[translated]
return original, exists return original, exists
} }
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() { if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid original IP address") return fmt.Errorf("invalid IP addresses")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
} }
if m.localipmanager.IsLocalIP(translatedAddr) { if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -171,6 +135,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil { if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap() m.dnatBiMap = newBiDNATMap()
@@ -186,7 +151,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil return nil
} }
// RemoveInternalDNATMapping removes a 1:1 IP address mapping. // RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
@@ -204,7 +169,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil return nil
} }
// getDNATTranslation returns the translated address if a mapping exists. // getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return addr, false return addr, false
@@ -216,7 +181,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists return translated, exists
} }
// findReverseDNATMapping finds original address for return traffic. // findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return translatedAddr, false return translatedAddr, false
@@ -228,12 +193,16 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists return original, exists
} }
// translateOutboundDNAT applies DNAT translation to outbound packets. // translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP) translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -241,8 +210,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err) m.logger.Error1("Failed to rewrite packet destination: %v", err)
return false return false
} }
@@ -250,12 +219,16 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true return true
} }
// translateInboundReverse applies reverse DNAT to inbound return traffic. // translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP) originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -263,8 +236,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err) m.logger.Error1("Failed to rewrite packet source: %v", err)
return false return false
} }
@@ -272,21 +245,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true return true
} }
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. // rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if !newIP.Is4() { if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only return ErrIPv4Only
} }
var oldIP [4]byte var oldDst [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4]) copy(oldDst[:], packetData[16:20])
newIPBytes := newIP.As4() newDst := newIP.As4()
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4 ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength return fmt.Errorf("invalid IP header length")
} }
binary.BigEndian.PutUint16(packetData[10:12], 0) binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -296,9 +269,44 @@ func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Add
if len(d.decoded) > 1 { if len(d.decoded) > 1 {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen) m.updateICMPChecksum(packetData, ipHeaderLen)
} }
@@ -307,7 +315,6 @@ func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Add
return nil return nil
} }
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 { if len(packetData) < tcpStart+18 {
@@ -320,7 +327,6 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen udpStart := ipHeaderLen
if len(packetData) < udpStart+8 { if len(packetData) < udpStart+8 {
@@ -338,7 +344,6 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 { if len(packetData) < icmpStart+8 {
@@ -351,7 +356,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum) binary.BigEndian.PutUint16(icmpData[2:4], checksum)
} }
// incrementalUpdate performs incremental checksum update per RFC 1624. // incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum) sum := uint32(^oldChecksum)
@@ -386,7 +391,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network. // AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return nil, errNatNotSupported return nil, errNatNotSupported
@@ -394,184 +399,10 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule) return m.nativeFirewall.AddDNATRule(rule)
} }
// DeleteDNATRule deletes outbound DNAT rule. // DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errNatNotSupported return errNatNotSupported
} }
return m.nativeFirewall.DeleteDNATRule(rule) return m.nativeFirewall.DeleteDNATRule(rule)
} }
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
if rule.origPort != port {
continue
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}

View File

@@ -414,127 +414,3 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
} }
}) })
} }
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -144,111 +143,3 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP) err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping") require.Error(t, err, "Should error when removing non-existent mapping")
} }
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,16 +16,12 @@ type PacketStage int
const ( const (
StageReceived PacketStage = iota StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack StageConntrack
StagePeerACL StagePeerACL
StageRouting StageRouting
StageRouteACL StageRouteACL
StageForwarding StageForwarding
StageCompleted StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
) )
const msgProcessingCompleted = "Processing completed" const msgProcessingCompleted = "Processing completed"
@@ -33,16 +29,12 @@ const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string { func (s PacketStage) String() string {
return map[PacketStage]string{ return map[PacketStage]string{
StageReceived: "Received", StageReceived: "Received",
StageInboundPortDNAT: "Inbound Port DNAT",
StageInbound1to1NAT: "Inbound 1:1 NAT",
StageConntrack: "Connection Tracking", StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL", StagePeerACL: "Peer ACL",
StageRouting: "Routing", StageRouting: "Routing",
StageRouteACL: "Route ACL", StageRouteACL: "Route ACL",
StageForwarding: "Forwarding", StageForwarding: "Forwarding",
StageCompleted: "Completed", StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
}[s] }[s]
} }
@@ -269,10 +261,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
@@ -412,16 +400,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
} }
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
d := m.decoders.Get().(*decoder) // will create or update the connection state
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
dropped := m.filterOutbound(packetData, 0) dropped := m.filterOutbound(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -430,199 +409,3 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
} }
return trace return trace
} }
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -104,8 +104,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -128,8 +126,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -157,8 +153,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -185,8 +179,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -212,8 +204,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -238,8 +228,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -258,8 +246,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -278,8 +264,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageCompleted, StageCompleted,
@@ -303,8 +287,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageCompleted, StageCompleted,
}, },
@@ -319,8 +301,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted, StageCompleted,
}, },
expectedAllow: true, expectedAllow: true,
@@ -339,8 +319,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -362,8 +340,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -386,8 +362,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -408,8 +382,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -434,8 +406,6 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
StageCompleted, StageCompleted,

View File

@@ -1,44 +0,0 @@
//go:build !js
package grpc
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}

View File

@@ -1,13 +0,0 @@
package grpc
import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy/client"
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled, component)
}

View File

@@ -3,7 +3,7 @@ package bind
import ( import (
wireguard "golang.zx2c4.com/wireguard/conn" wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -1,17 +1,5 @@
package bind package bind
import ( import wgConn "golang.zx2c4.com/wireguard/conn"
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -1,7 +0,0 @@
package bind
import "fmt"
var (
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
)

View File

@@ -1,9 +1,6 @@
//go:build !js
package bind package bind
import ( import (
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
@@ -11,18 +8,22 @@ import (
"runtime" "runtime"
"sync" "sync"
"github.com/pion/stun/v3" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct { type receiverCreator struct {
iceBind *ICEBind iceBind *ICEBind
} }
@@ -40,38 +41,37 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported. // use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct { type ICEBind struct {
*wgConn.StdNetBind *wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net transportNet transport.Net
filterFn udpmux.FilterFn filterFn FilterFn
address wgaddr.Address
mtu uint16
endpoints map[netip.Addr]net.Conn endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex endpointsMu sync.Mutex
recvChan chan recvMessage
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one // new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{} closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool closed bool
activityRecorder *ActivityRecorder
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder
} }
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet, transportNet: transportNet,
filterFn: filterFn, filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
recvChan: make(chan recvMessage, 1),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
mtu: mtu,
address: address,
activityRecorder: NewActivityRecorder(), activityRecorder: NewActivityRecorder(),
} }
@@ -82,6 +82,10 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg
return ib return ib
} }
func (s *ICEBind) MTU() uint16 {
return s.mtu
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false s.closed = false
s.closedChanMu.Lock() s.closedChanMu.Lock()
@@ -111,7 +115,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
} }
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind // GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
if s.udpMux == nil { if s.udpMux == nil {
@@ -134,16 +138,6 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
delete(b.endpoints, fakeIP) delete(b.endpoints, fakeIP)
} }
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-b.closedChan:
return
case <-ctx.Done():
return
case b.recvChan <- recvMessage{ep, buf}:
}
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock() b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()] conn, ok := b.endpoints[ep.DstIP()]
@@ -164,8 +158,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
s.udpMux = udpmux.NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn), UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
@@ -276,7 +270,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select { select {
case <-c.closedChan: case <-c.closedChan:
return 0, net.ErrClosed return 0, net.ErrClosed
case msg, ok := <-c.recvChan: case msg, ok := <-c.RecvChan:
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
} }

View File

@@ -1,6 +0,0 @@
package bind
type recvMessage struct {
Endpoint *Endpoint
Buffer []byte
}

View File

@@ -1,125 +0,0 @@
package bind
import (
"context"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
)
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
// Do not limit to build only js, because we want to be able to run tests
type RelayBindJS struct {
*conn.StdNetBind
recvChan chan recvMessage
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
activityRecorder *ActivityRecorder
ctx context.Context
cancel context.CancelFunc
}
func NewRelayBindJS() *RelayBindJS {
return &RelayBindJS{
recvChan: make(chan recvMessage, 100),
endpoints: make(map[netip.Addr]net.Conn),
activityRecorder: NewActivityRecorder(),
}
}
// Open creates a receive function for handling relay packets in WASM.
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
log.Debugf("Open: creating receive function for port %d", uport)
s.ctx, s.cancel = context.WithCancel(context.Background())
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
select {
case <-s.ctx.Done():
return 0, net.ErrClosed
case msg, ok := <-s.recvChan:
if !ok {
return 0, net.ErrClosed
}
copy(bufs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = conn.Endpoint(msg.Endpoint)
return 1, nil
}
}
log.Debugf("Open: receive function created, returning port %d", uport)
return []conn.ReceiveFunc{receiveFn}, uport, nil
}
func (s *RelayBindJS) Close() error {
if s.cancel == nil {
return nil
}
log.Debugf("close RelayBindJS")
s.cancel()
return nil
}
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-s.ctx.Done():
return
case <-ctx.Done():
return
case s.recvChan <- recvMessage{ep, buf}:
}
}
// Send forwards packets through the relay connection for WASM.
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
if ep == nil {
return nil
}
fakeIP := ep.DstIP()
s.endpointsMu.Lock()
relayConn, ok := s.endpoints[fakeIP]
s.endpointsMu.Unlock()
if !ok {
return nil
}
for _, buf := range bufs {
if _, err := relayConn.Write(buf); err != nil {
return err
}
}
return nil
}
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock()
}
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
delete(s.endpoints, fakeIP)
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, ErrUDPMUXNotSupported
}
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}

View File

@@ -1,4 +1,4 @@
package udpmux package bind
import ( import (
"fmt" "fmt"
@@ -8,9 +8,9 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/pion/ice/v4" "github.com/pion/ice/v3"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192 const receiveMTU = 8192
// SingleSocketUDPMux is an implementation of the interface // UDPMuxDefault is an implementation of the interface
type SingleSocketUDPMux struct { type UDPMuxDefault struct {
params Params params UDPMuxParams
closedChan chan struct{} closedChan chan struct{}
closeOnce sync.Once closeOnce sync.Once
@@ -32,9 +32,6 @@ type SingleSocketUDPMux struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn addressMap map[string][]*udpMuxedConn
@@ -49,8 +46,8 @@ type SingleSocketUDPMux struct {
const maxAddrSize = 512 const maxAddrSize = 512
// Params are parameters for UDPMux. // UDPMuxParams are parameters for UDPMux.
type Params struct { type UDPMuxParams struct {
Logger logging.LeveledLogger Logger logging.LeveledLogger
UDPConn net.PacketConn UDPConn net.PacketConn
@@ -150,18 +147,17 @@ func isZeros(ip net.IP) bool {
return true return true
} }
// NewSingleSocketUDPMux creates an implementation of UDPMux // NewUDPMuxDefault creates an implementation of UDPMux
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = getLogger() params.Logger = getLogger()
} }
mux := &SingleSocketUDPMux{ mux := &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{}, addressMap: map[string][]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn), connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1), closedChan: make(chan struct{}, 1),
pool: &sync.Pool{ pool: &sync.Pool{
New: func() interface{} { New: func() interface{} {
@@ -175,15 +171,15 @@ func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
return mux return mux
} }
func (m *SingleSocketUDPMux) updateLocalAddresses() { func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() { } else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection // it will break the applications that are already using unspecified UDP connection
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux. // with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType var networks []ice.NetworkType
switch { switch {
@@ -220,13 +216,13 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
m.mu.Unlock() m.mu.Unlock()
} }
// LocalAddr returns the listening address of this SingleSocketUDPMux // LocalAddr returns the listening address of this UDPMuxDefault
func (m *SingleSocketUDPMux) LocalAddr() net.Addr { func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr() return m.params.UDPConn.LocalAddr()
} }
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses() m.updateLocalAddresses()
m.mu.Lock() m.mu.Lock()
@@ -240,7 +236,7 @@ func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address // GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) { func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
m.mu.Lock() m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified) lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -264,14 +260,12 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
return conn, nil return conn, nil
} }
c := m.createMuxedConn(ufrag, candidateID) c := m.createMuxedConn(ufrag)
go func() { go func() {
<-c.CloseChannel() <-c.CloseChannel()
m.RemoveConnByUfrag(ufrag) m.RemoveConnByUfrag(ufrag)
}() }()
m.candidateConnMap[candidateID] = c
if isIPv6 { if isIPv6 {
m.connsIPv6[ufrag] = c m.connsIPv6[ufrag] = c
} else { } else {
@@ -282,7 +276,7 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2) removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock // Keep lock section small to avoid deadlock with conn lock
@@ -290,12 +284,10 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok { if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag) delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
if c, ok := m.connsIPv6[ufrag]; ok { if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag) delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
m.mu.Unlock() m.mu.Unlock()
@@ -322,7 +314,7 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
} }
// IsClosed returns true if the mux had been closed // IsClosed returns true if the mux had been closed
func (m *SingleSocketUDPMux) IsClosed() bool { func (m *UDPMuxDefault) IsClosed() bool {
select { select {
case <-m.closedChan: case <-m.closedChan:
return true return true
@@ -332,7 +324,7 @@ func (m *SingleSocketUDPMux) IsClosed() bool {
} }
// Close the mux, no further connections could be created // Close the mux, no further connections could be created
func (m *SingleSocketUDPMux) Close() error { func (m *UDPMuxDefault) Close() error {
var err error var err error
m.closeOnce.Do(func() { m.closeOnce.Do(func() {
m.mu.Lock() m.mu.Lock()
@@ -355,11 +347,11 @@ func (m *SingleSocketUDPMux) Close() error {
return err return err
} }
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr) return m.params.UDPConn.WriteTo(buf, rAddr)
} }
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) { func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() { if m.IsClosed() {
return return
} }
@@ -376,109 +368,81 @@ func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr str
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn { func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{ c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m, Mux: m,
Key: key, Key: key,
AddrPool: m.pool, AddrPool: m.pool,
LocalAddr: m.LocalAddr(), LocalAddr: m.LocalAddr(),
Logger: m.params.Logger, Logger: m.params.Logger,
CandidateID: candidateID,
}) })
return c return c
} }
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr) remoteAddr, ok := addr.(*net.UDPAddr)
if !ok { if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr") return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
} }
// Try to route to specific candidate connection first // If we have already seen this address dispatch to the appropriate destination
if conn := m.findCandidateConnection(msg); conn != nil { // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
return conn.writePacket(msg.Raw, remoteAddr) // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
} // We will then forward STUN packets to each of these connections.
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn
// Add connections from address map
m.addressMapMu.RLock() m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.RUnlock() m.addressMapMu.RUnlock()
if conn, ok := m.findConnectionByUsername(msg, addr); ok { var isIPv6 bool
// If we have already seen this address dispatch to the appropriate destination if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one isIPv6 = true
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
} }
// Forward to all found connections // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList { for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err) log.Errorf("could not write packet: %v", err)
} }
} }
return nil return nil
} }
// findConnectionByUsername finds connection using username attribute from STUN message func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 { if isIPv6 {
val, ok = m.connsIPv6[ufrag] val, ok = m.connsIPv6[ufrag]
} else { } else {
@@ -487,13 +451,6 @@ func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedCo
return return
} }
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct { type bufferHolder struct {
buf []byte buf []byte
} }

View File

@@ -1,12 +1,12 @@
//go:build !ios //go:build !ios
package udpmux package bind
import ( import (
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr) conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package udpmux package bind
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -15,7 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v3" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing. // It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct { type UniversalUDPMuxDefault struct {
*SingleSocketUDPMux *UDPMuxDefault
params UniversalUDPMuxParams params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
udpMuxParams := Params{ udpMuxParams := UDPMuxParams{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
Net: m.params.Net, Net: m.params.Net,
} }
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams) m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m return m
} }
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server. // and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) { func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID) return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
} }
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
} }
return nil return nil
} }
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr) return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
} }
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -1,4 +1,4 @@
package udpmux package bind
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,12 +16,11 @@ import (
) )
type udpMuxedConnParams struct { type udpMuxedConnParams struct {
Mux *SingleSocketUDPMux Mux *UDPMuxDefault
AddrPool *sync.Pool AddrPool *sync.Pool
Key string Key string
LocalAddr net.Addr LocalAddr net.Addr
Logger logging.LeveledLogger Logger logging.LeveledLogger
CandidateID string
} }
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -120,10 +119,6 @@ func (c *udpMuxedConn) Close() error {
return err return err
} }
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool { func (c *udpMuxedConn) isClosed() bool {
select { select {
case <-c.closedChan: case <-c.closedChan:

View File

@@ -73,44 +73,6 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
// Get the existing peer to preserve its allowed IPs
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
removePeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
}
//Re-add the peer without the endpoint but same AllowedIPs
reAddPeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
AllowedIPs: existingPeer.AllowedIPs,
ReplaceAllowedIPs: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
return fmt.Errorf(
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
)
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error { func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
//go:build linux || windows || freebsd || js || wasip1 //go:build linux || windows || freebsd
package configurer package configurer

View File

@@ -1,4 +1,4 @@
//go:build !windows && !js //go:build !windows
package configurer package configurer

View File

@@ -1,23 +0,0 @@
package configurer
import (
"net"
)
type noopListener struct{}
func (n *noopListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (n *noopListener) Close() error {
return nil
}
func (n *noopListener) Addr() net.Addr {
return nil
}
func openUAPI(deviceName string) (net.Listener, error) {
return &noopListener{}, nil
}

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -106,67 +106,6 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
ipcStr, err := c.device.IpcGet()
if err != nil {
return fmt.Errorf("get IPC config: %w", err)
}
// Parse current status to get allowed IPs for the peer
stats, err := parseStatus(c.deviceName, ipcStr)
if err != nil {
return fmt.Errorf("parse IPC config: %w", err)
}
var allowedIPs []net.IPNet
found := false
for _, peer := range stats.Peers {
if peer.PublicKey == peerKey {
allowedIPs = peer.AllowedIPs
found = true
break
}
}
if !found {
return fmt.Errorf("peer %s not found", peerKey)
}
// remove the peer from the WireGuard configuration
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return fmt.Errorf("failed to remove peer: %s", ipcErr)
}
// Build the peer config
peer = wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: allowedIPs,
}
config = wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
return fmt.Errorf("remove endpoint address: %w", err)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
@@ -455,13 +394,6 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
} }
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil return time.Unix(sec, 0), nil
} }
@@ -470,7 +402,7 @@ func toBytes(s string) (int64, error) {
} }
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" { if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark return nbnet.ControlPlaneMark
} }
return 0 return 0

View File

@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*udpmux.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16
@@ -23,5 +23,4 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device Device() *wgdevice.Device
GetNet() *netstack.Net GetNet() *netstack.Net
GetICEBind() device.EndpointManager
} }

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -30,7 +29,7 @@ type WGTunDevice struct {
name string name string
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -89,7 +88,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -150,11 +149,6 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *WGTunDevice) GetICEBind() EndpointManager {
return t.iceBind
}
func routesToString(routes []string) string { func routesToString(routes []string) string {
return strings.Join(routes, ";") return strings.Join(routes, ";")
} }

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -27,7 +26,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -72,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -154,8 +153,3 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -29,7 +28,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -84,7 +83,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -144,8 +143,3 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -12,11 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink link *wgLink
udpMuxConn net.PacketConn udpMuxConn net.PacketConn
udpMux *udpmux.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
filterFn udpmux.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil return configurer, nil
} }
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil { if t.udpMux != nil {
return t.udpMux, nil return t.udpMux, nil
} }
@@ -101,14 +101,19 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
bindParams := udpmux.UniversalUDPMuxParams{ var udpConn net.PacketConn = rawSock
UDPConn: nbnet.WrapPacketConn(rawSock), if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address, WGAddress: t.address,
MTU: t.mtu, MTU: t.mtu,
} }
mux := udpmux.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock t.udpMuxConn = rawSock
t.udpMux = mux t.udpMux = mux
@@ -179,8 +184,3 @@ func (t *TunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) GetNet() *netstack.Net { func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns nil for kernel mode devices
func (t *TunKernelDevice) GetICEBind() EndpointManager {
return nil
}

View File

@@ -1,29 +1,19 @@
package device package device
import ( import (
"errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
EndpointManager
}
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address wgaddr.Address address wgaddr.Address
@@ -31,18 +21,18 @@ type TunNetstackDevice struct {
key string key string
mtu uint16 mtu uint16
listenAddress string listenAddress string
bind Bind iceBind *bind.ICEBind
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun nsTun *nbnetstack.NetStackTun
udpMux *udpmux.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -50,7 +40,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
key: key, key: key,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
bind: bind, iceBind: iceBind,
} }
} }
@@ -75,11 +65,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.filteredDevice, t.filteredDevice,
t.bind, t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
_ = tunIface.Close() _ = tunIface.Close()
@@ -90,7 +80,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }
@@ -100,15 +90,11 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
udpMux, err := t.bind.GetICEMux() udpMux, err := t.iceBind.GetICEMux()
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) { if err != nil {
return nil, err return nil, err
} }
if udpMux != nil {
t.udpMux = udpMux t.udpMux = udpMux
}
log.Debugf("netstack device is ready to use") log.Debugf("netstack device is ready to use")
return udpMux, nil return udpMux, nil
} }
@@ -156,8 +142,3 @@ func (t *TunNetstackDevice) Device() *device.Device {
func (t *TunNetstackDevice) GetNet() *netstack.Net { func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net return t.net
} }
// GetICEBind returns the bind instance
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
return t.bind
}

View File

@@ -1,27 +0,0 @@
package device
import (
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestNewNetstackDevice(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
relayBind := bind.NewRelayBindJS()
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
cfgr, err := nsTun.Create()
if err != nil {
t.Fatalf("failed to create netstack device: %v", err)
}
if cfgr == nil {
t.Fatal("expected non-nil configurer")
}
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -26,7 +25,7 @@ type USPDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -75,7 +74,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }
@@ -146,8 +145,3 @@ func (t *USPDevice) assignAddr() error {
func (t *USPDevice) GetNet() *netstack.Net { func (t *USPDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -30,7 +29,7 @@ type TunDevice struct {
device *device.Device device *device.Device
nativeTunDevice *tun.NativeTun nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -105,7 +104,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -185,8 +184,3 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net { func (t *TunDevice) GetNet() *netstack.Net {
return nil return nil
} }
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -1,13 +0,0 @@
package device
import (
"net"
"net/netip"
)
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
// Implemented by bind.ICEBind and bind.RelayBindJS.
type EndpointManager interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
}

View File

@@ -21,5 +21,4 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
} }

View File

@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*udpmux.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16
@@ -21,5 +21,4 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device Device() *wgdevice.Device
GetNet() *netstack.Net GetNet() *netstack.Net
GetICEBind() device.EndpointManager
} }

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16 MTU uint16
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn udpmux.FilterFn FilterFn bind.FilterFn
DisableDNS bool DisableDNS bool
} }
@@ -80,17 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy() return w.wgProxyFactory.GetProxy()
} }
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun == nil {
return nil
}
return w.tun.GetICEBind()
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool { func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind return w.userspaceBind
@@ -125,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface // Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before) // The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -159,17 +148,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing endpoint address: %s", peerKey)
return w.configurer.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock() w.mu.Lock()

View File

@@ -1,6 +0,0 @@
package iface
// Destroy is a no-op on WASM
func (w *WGIface) Destroy() error {
return nil
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil
} }
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: tun, tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -1,41 +0,0 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true, userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -1,27 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
}
return wgIface, nil
}

View File

@@ -1,4 +1,4 @@
//go:build linux && !android //go:build (linux && !android) || freebsd
package iface package iface
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil return wgIFace, nil
} }
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: tun, tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -1,5 +1,3 @@
//go:build !js
package netstack package netstack
import ( import (

View File

@@ -1,12 +0,0 @@
package netstack
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled always returns true for js since it's the only mode available
func IsEnabled() bool {
return true
}
func ListenAddr() string {
return ""
}

View File

@@ -1,64 +0,0 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,7 +0,0 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -16,38 +16,28 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type Bind interface {
SetEndpoint(addr netip.Addr, conn net.Conn)
RemoveEndpoint(addr netip.Addr)
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
}
type ProxyBind struct { type ProxyBind struct {
bind Bind Bind *bind.ICEBind
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address fakeNetIP *netip.AddrPort
wgRelayedEndpoint *bind.Endpoint wgBindEndpoint *bind.Endpoint
wgCurrentUsed *bind.Endpoint
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
mtu uint16
} }
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind { func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{ p := &ProxyBind{
bind: bind, Bind: bind,
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
mtu: mtu + bufsize.WGBufferOverhead,
} }
return p return p
@@ -56,25 +46,25 @@ func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr) fakeNetIP, err := fakeAddress(nbAddr)
if err != nil { if err != nil {
return err return err
} }
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return nil return nil
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) { func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -86,21 +76,17 @@ func (p *ProxyBind) Work() {
return return
} }
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.pausedCond.L.Lock() p.pausedMu.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once // Start the proxy only once
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
func (p *ProxyBind) Pause() { func (p *ProxyBind) Pause() {
@@ -108,25 +94,9 @@ func (p *ProxyBind) Pause() {
return return
} }
p.pausedCond.L.Lock() p.pausedMu.Lock()
p.paused = true p.paused = true
p.pausedCond.L.Unlock() p.pausedMu.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
} }
func (p *ProxyBind) CloseConn() error { func (p *ProxyBind) CloseConn() error {
@@ -137,10 +107,6 @@ func (p *ProxyBind) CloseConn() error {
} }
func (p *ProxyBind) close() error { func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -154,12 +120,7 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.pausedCond.L.Lock() p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -175,7 +136,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}() }()
for { for {
buf := make([]byte, p.mtu) buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -186,13 +147,18 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedCond.L.Lock() p.pausedMu.Lock()
for p.paused { if p.paused {
p.pausedCond.Wait() p.pausedMu.Unlock()
continue
} }
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n]) msg := bind.RecvMessage{
p.pausedCond.L.Unlock() Endpoint: p.wgBindEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
} }
} }

View File

@@ -6,7 +6,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"syscall"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -16,20 +18,15 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
@@ -67,7 +64,7 @@ func (p *WGEBPFProxy) Listen() error {
return err return err
} }
p.rawConn, err = rawsocket.PrepareSenderRawSocket() p.rawConn, err = p.prepareSenderRawSocket()
if err != nil { if err != nil {
return err return err
} }
@@ -217,17 +214,57 @@ generatePort:
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)
ipH := &layers.IPv4{ ipH := &layers.IPv4{
DstIP: localHostNetIP, DstIP: localhost,
SrcIP: endpointAddr.IP, SrcIP: localhost,
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
udpH := &layers.UDP{ udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port), SrcPort: layers.UDPPort(port),
DstPort: layers.UDPPort(p.localWGListenPort), DstPort: layers.UDPPort(p.localWGListenPort),
} }
@@ -242,7 +279,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
if err != nil { if err != nil {
return fmt.Errorf("serialize layers: %w", err) return fmt.Errorf("serialize layers: %w", err)
} }
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
return fmt.Errorf("write to raw conn: %w", err) return fmt.Errorf("write to raw conn: %w", err)
} }
return nil return nil

View File

@@ -18,42 +18,41 @@ import (
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr wgEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
} }
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{ return &ProxyWrapper{
wgeBPFProxy: proxy, WgeBPFProxy: WgeBPFProxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr p.wgEndpointAddr = addr
return err return err
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgRelayedEndpointAddr return p.wgEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -65,18 +64,14 @@ func (p *ProxyWrapper) Work() {
return return
} }
p.pausedCond.L.Lock() p.pausedMu.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
func (p *ProxyWrapper) Pause() { func (p *ProxyWrapper) Pause() {
@@ -85,59 +80,45 @@ func (p *ProxyWrapper) Pause() {
} }
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedCond.L.Lock() p.pausedMu.Lock()
p.paused = true p.paused = true
p.pausedCond.L.Unlock() p.pausedMu.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
func (p *ProxyWrapper) CloseConn() error { func (e *ProxyWrapper) CloseConn() error {
if p.cancel == nil { if e.cancel == nil {
return fmt.Errorf("proxy not started") return fmt.Errorf("proxy not started")
} }
p.cancel() e.cancel()
p.closeListener.SetCloseListener(nil) e.closeListener.SetCloseListener(nil)
p.pausedCond.L.Lock() if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.paused = false return fmt.Errorf("close remote conn: %w", err)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port)) defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead) buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.readFromRemote(ctx, buf) n, err := p.readFromRemote(ctx, buf)
if err != nil { if err != nil {
return return
} }
p.pausedCond.L.Lock() p.pausedMu.Lock()
for p.paused { if p.paused {
p.pausedCond.Wait() p.pausedMu.Unlock()
continue
} }
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedCond.L.Unlock() p.pausedMu.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -156,7 +137,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
} }
p.closeListener.Notify() p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
} }
return 0, err return 0, err
} }

View File

@@ -39,6 +39,7 @@ func (w *KernelFactory) GetProxy() Proxy {
} }
return ebpf.NewProxyWrapper(w.ebpfProxy) return ebpf.NewProxyWrapper(w.ebpfProxy)
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -0,0 +1,31 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
mtu uint16
}
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -3,25 +3,24 @@ package wgproxy
import ( import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
) )
type USPFactory struct { type USPFactory struct {
bind proxyBind.Bind bind *bind.ICEBind
mtu uint16
} }
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory { func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy") log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{ f := &USPFactory{
bind: bind, bind: iceBind,
mtu: mtu,
} }
return f return f
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind, w.mtu) return proxyBind.NewProxyBind(w.bind)
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -11,11 +11,6 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
//and rewrite the src address to the endpoint address.
//With this logic can avoid the package loss from relayed connections.
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func()) SetDisconnectListener(disconnected func())
} }

View File

@@ -3,82 +3,54 @@
package wgproxy package wgproxy
import ( import (
"fmt" "context"
"net" "os"
"testing"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
func seedProxies() ([]proxyInstance, error) { func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
pl := make([]proxyInstance, 0) if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) t.Fatalf("failed to initialize ebpf proxy: %s", err)
} }
pEbpf := proxyInstance{ defer func() {
name: "ebpf kernel proxy", if err := ebpfProxy.Free(); err != nil {
proxy: ebpf.NewProxyWrapper(ebpfProxy), t.Errorf("failed to free ebpf proxy: %s", err)
wgPort: 51831,
closeFn: ebpfProxy.Free,
} }
pl = append(pl, pEbpf) }()
pUDP := proxyInstance{ tests := []struct {
name: "udp kernel proxy", name string
proxy: udp.NewWGUDPProxy(51832, 1280), proxy Proxy
wgPort: 51832, }{
closeFn: func() error { return nil }, {
} name: "ebpf proxy",
pl = append(pl, pUDP) proxy: &ebpf.ProxyWrapper{
return pl, nil WgeBPFProxy: ebpfProxy,
} },
},
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
} }
pEbpf := proxyInstance{ for _, tt := range tests {
name: "ebpf kernel proxy", t.Run(tt.name, func(t *testing.T) {
proxy: ebpf.NewProxyWrapper(ebpfProxy), relayedConn := newMockConn()
wgPort: 51831, err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil { if err != nil {
return nil, err t.Errorf("error: %v", err)
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
} }
pBind := proxyInstance{ _ = relayedConn.Close()
name: "bind proxy", if err := tt.proxy.CloseConn(); err != nil {
proxy: bindproxy.NewProxyBind(iceBind, 0), t.Errorf("error: %v", err)
endpointAddr: endpointAddress, }
closeFn: func() error { return nil }, })
} }
pl = append(pl, pBind)
return pl, nil
} }

View File

@@ -1,39 +0,0 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux
package wgproxy package wgproxy
import ( import (
@@ -5,9 +7,12 @@ import (
"io" "io"
"net" "net"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -17,14 +22,6 @@ func TestMain(m *testing.M) {
os.Exit(code) os.Exit(code)
} }
type proxyInstance struct {
name string
proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
}
type mocConn struct { type mocConn struct {
closeChan chan struct{} closeChan chan struct{}
closed bool closed bool
@@ -81,21 +78,41 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
func TestProxyCloseByRemoteConn(t *testing.T) { func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tests, err := seedProxyForProxyCloseByRemoteConn() tests := []struct {
if err != nil { name string
t.Fatalf("error: %v", err) proxy Proxy
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
},
} }
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() { defer func() {
_ = relayedConn.Close() if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}() }()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
relayedConn := newMockConn() relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn) err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }
@@ -107,104 +124,3 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}) })
} }
} }
// TestProxyRedirect todo extend the proxies with Bind proxy
func TestProxyRedirect(t *testing.T) {
tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
t.Helper()
msgHelloFromRelay := []byte("hello from relay")
msgRedirected := [][]byte{
[]byte("hello 1. to p2p"),
[]byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
}
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: wgPort})
if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
}
relayedServer, _ := net.ListenUDP("udp",
&net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
}

View File

@@ -1,50 +0,0 @@
//go:build linux && !android
package rawsocket
import (
"fmt"
"net"
"os"
"syscall"
nbnet "github.com/netbirdio/netbird/client/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}

View File

@@ -1,5 +1,3 @@
//go:build linux && !android
package udp package udp
import ( import (
@@ -25,15 +23,13 @@ type WGUDPProxy struct {
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
srcFakerConn *SrcFaker
sendPkg func(data []byte) (int, error)
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
@@ -45,7 +41,6 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu, mtu: mtu,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
return p return p
@@ -66,7 +61,6 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn p.localConn = localConn
p.sendPkg = p.localConn.Write
p.remoteConn = remoteConn p.remoteConn = remoteConn
return err return err
@@ -90,24 +84,15 @@ func (p *WGUDPProxy) Work() {
return return
} }
p.pausedCond.L.Lock() p.pausedMu.Lock()
p.paused = false p.paused = false
p.sendPkg = p.localConn.Write p.pausedMu.Unlock()
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToRemote(p.ctx) go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
} }
// Pause pauses the proxy from receiving data from the remote peer // Pause pauses the proxy from receiving data from the remote peer
@@ -116,35 +101,9 @@ func (p *WGUDPProxy) Pause() {
return return
} }
p.pausedCond.L.Lock() p.pausedMu.Lock()
p.paused = true p.paused = true
p.pausedCond.L.Unlock() p.pausedMu.Unlock()
}
// RedirectAs start to use the fake sourced raw socket as package sender
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
defer func() {
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}()
p.paused = false
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create src faker conn: %s", err)
// fallback to continue without redirecting
p.paused = true
return
}
p.srcFakerConn = srcFakerConn
p.sendPkg = p.srcFakerConn.SendPkg
} }
// CloseConn close the localConn // CloseConn close the localConn
@@ -156,8 +115,6 @@ func (p *WGUDPProxy) CloseConn() error {
} }
func (p *WGUDPProxy) close() error { func (p *WGUDPProxy) close() error {
var result *multierror.Error
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -171,11 +128,7 @@ func (p *WGUDPProxy) close() error {
p.cancel() p.cancel()
p.pausedCond.L.Lock() var result *multierror.Error
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
} }
@@ -183,13 +136,6 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
}
}
return cerrors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
@@ -248,12 +194,14 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedCond.L.Lock() p.pausedMu.Lock()
for p.paused { if p.paused {
p.pausedCond.Wait() p.pausedMu.Unlock()
continue
} }
_, err = p.sendPkg(buf[:n])
p.pausedCond.L.Unlock() _, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {

View File

@@ -1,101 +0,0 @@
//go:build linux && !android
package udp
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
)
var (
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
}
func (f *SrcFaker) Close() error {
return f.rawSocket.Close()
}
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
defer func() {
if err := f.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return ipH, udpH, nil
}

View File

@@ -29,6 +29,11 @@ type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
} }
type protoMatch struct {
ips map[string]int
policyID []byte
}
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
firewall firewall.Manager firewall firewall.Manager
@@ -81,14 +86,21 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
} }
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := networkMap.FirewallRules rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := networkMap.PeerConfig != nil && enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled networkMap.PeerConfig.SshConfig.SshEnabled
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
// If SSH enabled, add default firewall rule which accepts connection to any peer // if TCP protocol rules not squashed and SSH enabled
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort). // we add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP 22 port).
if enableSSH { if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{ rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
@@ -356,6 +368,145 @@ func (d *DefaultManager) getPeerRuleID(
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
} }
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic.
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
totalIPs++
}
}
in := map[mgmProto.RuleProtocol]*protoMatch{}
out := map[mgmProto.RuleProtocol]*protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if:
// 1. Any of the rule has DROP action type.
// 2. Any of rule contains Port.
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
}
// special case, when we receive this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return
}
ipset := protocols[r.Protocol].ips
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
}
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
}
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
protocolOrders := []mgmProto.RuleProtocol{
mgmProto.RuleProtocol_ALL,
mgmProto.RuleProtocol_ICMP,
mgmProto.RuleProtocol_TCP,
mgmProto.RuleProtocol_UDP,
}
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
// don't squash if :
// 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network.
continue
}
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
PolicyID: match.policyID,
})
squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
}
}
}
squash(in, mgmProto.RuleDirection_IN)
squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
if len(squashedRules) == 0 {
return networkMap.FirewallRules, squashedProtocols
}
var rules []*mgmProto.FirewallRule
// filter out rules which was squashed from final list
// if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
}
}
rules = append(rules, r)
}
return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector // getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)

View File

@@ -188,6 +188,492 @@ func TestDefaultManagerStateless(t *testing.T) {
}) })
} }
func TestDefaultManagerSquashRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules))
r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
}
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) { func TestPortInfoEmpty(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -34,7 +34,7 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -275,7 +275,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock() c.engineMutex.Unlock()
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { if err := c.engine.Start(); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
@@ -284,8 +284,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil {
close(runningChan) select {
runningChan = nil case runningChan <- struct{}{}:
default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()

Some files were not shown because too many files have changed in this diff Show More