Merge branch 'feat/auto-upgrade' into auto-upgrade-mod

This commit is contained in:
M Essam Hamed
2025-10-06 14:58:36 +03:00
245 changed files with 7275 additions and 2079 deletions

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: ""
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Install Go

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
skip: go.mod,go.sum
golangci:
strategy:

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.22"
SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -0,0 +1,67 @@
name: Wasm
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
js_lint:
name: "JS / Lint"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
with:
version: latest
install-mode: binary
skip-cache: true
skip-pkg-cache: true
skip-build-cache: true
- name: Run golangci-lint for WASM
run: |
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
continue-on-error: true
js_build:
name: "JS / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
CGO_ENABLED: 0
- name: Check Wasm build size
run: |
echo "Wasm build size:"
ls -lh netbird.wasm
SIZE=$(stat -c%s netbird.wasm)
SIZE_MB=$((SIZE / 1024 / 1024))
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
exit 1
fi

0
.gitmodules vendored Normal file
View File

View File

@@ -2,6 +2,18 @@ version: 2
project_name: netbird
builds:
- id: netbird-wasm
dir: client/wasm/cmd
binary: netbird
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
goos:
- js
goarch:
- wasm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird
dir: client
binary: netbird
@@ -115,6 +127,11 @@ archives:
- builds:
- netbird
- netbird-static
- id: netbird-wasm
builds:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
nfpms:
- maintainer: Netbird <dev@netbird.io>

View File

@@ -1,3 +1,4 @@
<div align="center">
<br/>
<br/>
@@ -52,7 +53,7 @@
### Open Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -4,6 +4,7 @@ package android
import (
"context"
"os"
"slices"
"sync"
@@ -18,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/net"
)
// ConnectionListener export internal Listener for mobile
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -0,0 +1,32 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -33,6 +33,7 @@ type ErrListener interface {
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
OnLoginSuccess()
}
// Auth can register or login new client
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}

8
client/cmd/debug_js.go Normal file
View File

@@ -0,0 +1,8 @@
package cmd
import "context"
// SetupDebugHandler is a no-op for WASM
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
// Debug handler not needed for WASM
}

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
return grpc.DialContext(

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
@@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}

View File

@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}

View File

@@ -23,23 +23,29 @@ import (
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
// Client manages a netbird embedded client instance
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
config *profilemanager.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
jwtToken string
connect *internal.ConnectClient
}
// Options configures a new Client
// Options configures a new Client.
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// JWTToken is used for JWT-based authentication
JWTToken string
// PrivateKey is used for direct private key authentication
PrivateKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
@@ -58,8 +64,35 @@ type Options struct {
DisableClientRoutes bool
}
// New creates a new netbird embedded client
// validateCredentials checks that exactly one credential type is provided
func (opts *Options) validateCredentials() error {
credentialsProvided := 0
if opts.SetupKey != "" {
credentialsProvided++
}
if opts.JWTToken != "" {
credentialsProvided++
}
if opts.PrivateKey != "" {
credentialsProvided++
}
if credentialsProvided == 0 {
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
}
if credentialsProvided > 1 {
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
}
return nil
}
// New creates a new netbird embedded client.
func New(opts Options) (*Client, error) {
if err := opts.validateCredentials(); err != nil {
return nil, err
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) {
return nil, fmt.Errorf("create config: %w", err)
}
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
}, nil
}
@@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
@@ -135,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
@@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error {
}
}
// GetConfig returns a copy of the internal client config.
func (c *Client) GetConfig() (profilemanager.Config, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.config == nil {
return profilemanager.Config{}, ErrConfigNotInitialized
}
return *c.config, nil
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
@@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// ListenUDP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()

View File

@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// constants needed to manage and create iptable rules

View File

@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
func isIptablesSupported() bool {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -4,15 +4,9 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os/user"
"runtime"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -21,35 +15,9 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
@@ -57,7 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
certPool, err := x509.SystemCertPool()
@@ -67,18 +37,20 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
}
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool,
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool,
}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(),
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,

View File

@@ -0,0 +1,44 @@
//go:build !js
package grpc
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/client/net"
)
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}

13
client/grpc/dialer_js.go Normal file
View File

@@ -0,0 +1,13 @@
package grpc
import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy/client"
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled, component)
}

View File

@@ -3,7 +3,7 @@ package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -1,5 +1,17 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -0,0 +1,7 @@
package bind
import "fmt"
var (
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
)

View File

@@ -1,6 +1,9 @@
//go:build !js
package bind
import (
"context"
"encoding/binary"
"fmt"
"net"
@@ -15,15 +18,11 @@ import (
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
@@ -41,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
filterFn udpmux.FilterFn
address wgaddr.Address
mtu uint16
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
recvChan chan recvMessage
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
activityRecorder *ActivityRecorder
muUDPMux sync.Mutex
udpMux *udpmux.UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn),
recvChan: make(chan recvMessage, 1),
closedChan: make(chan struct{}),
closed: true,
mtu: mtu,
address: address,
activityRecorder: NewActivityRecorder(),
}
@@ -82,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
return ib
}
func (s *ICEBind) MTU() uint16 {
return s.mtu
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
@@ -115,7 +111,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
@@ -138,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
delete(b.endpoints, fakeIP)
}
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-b.closedChan:
return
case <-ctx.Done():
return
case b.recvChan <- recvMessage{ep, buf}:
}
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
@@ -158,8 +164,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,
@@ -270,7 +276,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
case msg, ok := <-c.recvChan:
if !ok {
return 0, net.ErrClosed
}

View File

@@ -0,0 +1,6 @@
package bind
type recvMessage struct {
Endpoint *Endpoint
Buffer []byte
}

View File

@@ -0,0 +1,125 @@
package bind
import (
"context"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
)
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
// Do not limit to build only js, because we want to be able to run tests
type RelayBindJS struct {
*conn.StdNetBind
recvChan chan recvMessage
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
activityRecorder *ActivityRecorder
ctx context.Context
cancel context.CancelFunc
}
func NewRelayBindJS() *RelayBindJS {
return &RelayBindJS{
recvChan: make(chan recvMessage, 100),
endpoints: make(map[netip.Addr]net.Conn),
activityRecorder: NewActivityRecorder(),
}
}
// Open creates a receive function for handling relay packets in WASM.
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
log.Debugf("Open: creating receive function for port %d", uport)
s.ctx, s.cancel = context.WithCancel(context.Background())
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
select {
case <-s.ctx.Done():
return 0, net.ErrClosed
case msg, ok := <-s.recvChan:
if !ok {
return 0, net.ErrClosed
}
copy(bufs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = conn.Endpoint(msg.Endpoint)
return 1, nil
}
}
log.Debugf("Open: receive function created, returning port %d", uport)
return []conn.ReceiveFunc{receiveFn}, uport, nil
}
func (s *RelayBindJS) Close() error {
if s.cancel == nil {
return nil
}
log.Debugf("close RelayBindJS")
s.cancel()
return nil
}
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
select {
case <-s.ctx.Done():
return
case <-ctx.Done():
return
case s.recvChan <- recvMessage{ep, buf}:
}
}
// Send forwards packets through the relay connection for WASM.
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
if ep == nil {
return nil
}
fakeIP := ep.DstIP()
s.endpointsMu.Lock()
relayConn, ok := s.endpoints[fakeIP]
s.endpointsMu.Unlock()
if !ok {
return nil
}
for _, buf := range bufs {
if _, err := relayConn.Write(buf); err != nil {
return err
}
}
return nil
}
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock()
}
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
delete(s.endpoints, fakeIP)
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, ErrUDPMUXNotSupported
}
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}

View File

@@ -1,7 +0,0 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
//go:build linux || windows || freebsd
//go:build linux || windows || freebsd || js || wasip1
package configurer

View File

@@ -1,4 +1,4 @@
//go:build !windows
//go:build !windows && !js
package configurer

View File

@@ -0,0 +1,23 @@
package configurer
import (
"net"
)
type noopListener struct{}
func (n *noopListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (n *noopListener) Close() error {
return nil
}
func (n *noopListener) Addr() net.Addr {
return nil
}
func openUAPI(deviceName string) (net.Listener, error) {
return &noopListener{}, nil
}

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil
}
@@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) {
}
func getFwmark() int {
if nbnet.AdvancedRouting() {
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
return nbnet.ControlPlaneMark
}
return 0

View File

@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +30,7 @@ type WGTunDevice struct {
name string
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
}
return t.configurer, nil
}
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -26,7 +27,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -28,7 +29,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -12,11 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
filterFn bind.FilterFn
filterFn udpmux.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil
}
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -101,19 +101,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux

View File

@@ -1,19 +1,28 @@
package device
import (
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
}
type TunNetstackDevice struct {
name string
address wgaddr.Address
@@ -21,18 +30,18 @@ type TunNetstackDevice struct {
key string
mtu uint16
listenAddress string
iceBind *bind.ICEBind
bind Bind
device *device.Device
filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
}
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -40,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: iceBind,
bind: bind,
}
}
@@ -65,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
t.bind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
@@ -80,7 +89,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -90,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
udpMux, err := t.bind.GetICEMux()
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
return nil, err
}
t.udpMux = udpMux
if udpMux != nil {
t.udpMux = udpMux
}
log.Debugf("netstack device is ready to use")
return udpMux, nil
}

View File

@@ -0,0 +1,27 @@
package device
import (
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestNewNetstackDevice(t *testing.T) {
privateKey, _ := wgtypes.GeneratePrivateKey()
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
relayBind := bind.NewRelayBindJS()
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
cfgr, err := nsTun.Create()
if err != nil {
t.Fatalf("failed to create netstack device: %v", err)
}
if cfgr == nil {
t.Fatal("expected non-nil configurer")
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -25,7 +26,7 @@ type USPDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +30,7 @@ type TunDevice struct {
device *device.Device
nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
FilterFn udpmux.FilterFn
DisableDNS bool
}
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
w.mu.Lock()
defer w.mu.Unlock()

View File

@@ -0,0 +1,6 @@
package iface
// Destroy is a no-op on WASM
func (w *WGIface) Destroy() error {
return nil
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,41 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,27 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
relayBind := bind.NewRelayBindJS()
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
}
return wgIface, nil
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build linux && !android
package iface
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
return wgIFace, nil
}

View File

@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
}
return wgIFace, nil

View File

@@ -1,3 +1,5 @@
//go:build !js
package netstack
import (

View File

@@ -0,0 +1,12 @@
package netstack
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled always returns true for js since it's the only mode available
func IsEnabled() bool {
return true
}
func ListenAddr() string {
return ""
}

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,11 +16,12 @@ import (
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
Mux *SingleSocketUDPMux
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
CandidateID string
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
return err
}
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:

View File

@@ -0,0 +1,64 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
import (
"fmt"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
// SingleSocketUDPMux is an implementation of the interface
type SingleSocketUDPMux struct {
params Params
closedChan chan struct{}
closeOnce sync.Once
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
// Params are parameters for UDPMux.
type Params struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
@@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool {
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
// NewSingleSocketUDPMux creates an implementation of UDPMux
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
if params.Logger == nil {
params.Logger = getLogger()
}
mux := &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
// big enough buffer to fit both packet and address
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
func (m *SingleSocketUDPMux) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
m.mu.Unlock()
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
// LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return conn, nil
}
c := m.createMuxedConn(ufrag)
c := m.createMuxedConn(ufrag, candidateID)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
m.candidateConnMap[candidateID] = c
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
m.mu.Unlock()
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
func (m *SingleSocketUDPMux) IsClosed() bool {
select {
case <-m.closedChan:
return true
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
func (m *SingleSocketUDPMux) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
CandidateID: candidateID,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.RLock()
// Try to route to specific candidate connection first
if conn := m.findCandidateConnection(msg); conn != nil {
return conn.writePacket(msg.Raw, remoteAddr)
}
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn
// Add connections from address map
m.addressMapMu.RLock()
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.RUnlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
// Forward to all found connections
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
// findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
return
}
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct {
buf []byte
}

View File

@@ -1,12 +1,12 @@
//go:build !ios
package bind
package udpmux
import (
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
*SingleSocketUDPMux
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
udpMuxParams := UDPMuxParams{
udpMuxParams := Params{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
return m
}
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -16,28 +16,38 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
type ProxyBind struct {
Bind *bind.ICEBind
fakeNetIP *netip.AddrPort
wgBindEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
closeListener *listener.CloseListener
type Bind interface {
SetEndpoint(addr netip.Addr, conn net.Conn)
RemoveEndpoint(addr netip.Addr)
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
type ProxyBind struct {
bind Bind
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgRelayedEndpoint *bind.Endpoint
wgCurrentUsed *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
mtu uint16
}
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
p := &ProxyBind{
Bind: bind,
bind: bind,
closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
mtu: mtu + bufsize.WGBufferOverhead,
}
return p
@@ -46,25 +56,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr)
if err != nil {
return err
}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return nil
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
}
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -76,17 +86,21 @@ func (p *ProxyBind) Work() {
return
}
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func (p *ProxyBind) Pause() {
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}
func (p *ProxyBind) CloseConn() error {
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock()
defer p.closeMu.Unlock()
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}()
for {
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
buf := make([]byte, p.mtu)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
@@ -147,18 +186,13 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
msg := bind.RecvMessage{
Endpoint: p.wgBindEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
p.pausedCond.L.Unlock()
}
}

View File

@@ -6,9 +6,7 @@ import (
"context"
"fmt"
"net"
"os"
"sync"
"syscall"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -18,15 +16,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
return err
}
p.rawConn, err = p.prepareSenderRawSocket()
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil {
return err
}
@@ -214,57 +217,17 @@ generatePort:
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localhost,
SrcIP: localhost,
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(port),
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil

View File

@@ -18,41 +18,42 @@ import (
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
wgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr
p.wgRelayedEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
return p.wgRelayedEndpointAddr
}
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) Pause() {
@@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() {
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
p.cancel()
e.closeListener.SetCloseListener(nil)
p.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("close remote conn: %w", err)
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedCond.L.Unlock()
if err != nil {
if ctx.Err() != nil {
@@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
p.closeListener.Notify()
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
}
return 0, err
}

View File

@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
}
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
func (w *KernelFactory) Free() error {

View File

@@ -1,31 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
mtu uint16
}
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -3,24 +3,25 @@ package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
type USPFactory struct {
bind *bind.ICEBind
bind proxyBind.Bind
mtu uint16
}
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{
bind: iceBind,
bind: bind,
mtu: mtu,
}
return f
}
func (w *USPFactory) GetProxy() Proxy {
return proxyBind.NewProxyBind(w.bind)
return proxyBind.NewProxyBind(w.bind, w.mtu)
}
func (w *USPFactory) Free() error {

View File

@@ -11,6 +11,11 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
//and rewrite the src address to the endpoint address.
//With this logic can avoid the package loss from relayed connections.
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error
SetDisconnectListener(disconnected func())
}

View File

@@ -3,54 +3,82 @@
package wgproxy
import (
"context"
"os"
"testing"
"fmt"
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
func seedProxies() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -0,0 +1,39 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind, 0),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,5 +1,3 @@
//go:build linux
package wgproxy
import (
@@ -7,12 +5,9 @@ import (
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util"
)
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
type proxyInstance struct {
name string
proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
}
type mocConn struct {
closeChan chan struct{}
closed bool
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
},
tests, err := seedProxyForProxyCloseByRemoteConn()
if err != nil {
t.Fatalf("error: %v", err)
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = relayedConn.Close()
}()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
})
}
}
// TestProxyRedirect todo extend the proxies with Bind proxy
func TestProxyRedirect(t *testing.T) {
tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
t.Helper()
msgHelloFromRelay := []byte("hello from relay")
msgRedirected := [][]byte{
[]byte("hello 1. to p2p"),
[]byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
}
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: wgPort})
if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
}
relayedServer, _ := net.ListenUDP("udp",
&net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
}

View File

@@ -0,0 +1,50 @@
//go:build linux && !android
package rawsocket
import (
"fmt"
"net"
"os"
"syscall"
nbnet "github.com/netbirdio/netbird/client/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux && !android
package udp
import (
@@ -21,16 +23,18 @@ type WGUDPProxy struct {
localWGListenPort int
mtu uint16
remoteConn net.Conn
localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
remoteConn net.Conn
localConn net.Conn
srcFakerConn *SrcFaker
sendPkg func(data []byte) (int, error)
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
}
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
p := &WGUDPProxy{
localWGListenPort: wgPort,
mtu: mtu,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(),
}
return p
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn
p.sendPkg = p.localConn.Write
p.remoteConn = remoteConn
return err
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.sendPkg = p.localConn.Write
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}
// Pause pauses the proxy from receiving data from the remote peer
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
// RedirectAs start to use the fake sourced raw socket as package sender
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
defer func() {
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}()
p.paused = false
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create src faker conn: %s", err)
// fallback to continue without redirecting
p.paused = true
return
}
p.srcFakerConn = srcFakerConn
p.sendPkg = p.srcFakerConn.SendPkg
}
// CloseConn close the localConn
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
}
func (p *WGUDPProxy) close() error {
var result *multierror.Error
p.closeMu.Lock()
defer p.closeMu.Unlock()
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
p.cancel()
var result *multierror.Error
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
}
}
return cerrors.FormatErrorOrNil(result)
}
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
_, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
_, err = p.sendPkg(buf[:n])
p.pausedCond.L.Unlock()
if err != nil {
if ctx.Err() != nil {

View File

@@ -0,0 +1,101 @@
//go:build linux && !android
package udp
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
)
var (
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
}
func (f *SrcFaker) Close() error {
return f.rawSocket.Close()
}
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
defer func() {
if err := f.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return ipH, udpH, nil
}

View File

@@ -35,7 +35,7 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -297,10 +297,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
state.Set(StatusConnected)
if runningChan != nil {
select {
case runningChan <- struct{}{}:
default:
}
close(runningChan)
runningChan = nil
}
<-engineCtx.Done()

View File

@@ -240,15 +240,19 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
if r.gpo {
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
}
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
}
if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
}
}
log.Debugf("added NRPT entry for domain: %s", domain)
@@ -401,6 +405,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
}
@@ -412,6 +417,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
}

View File

@@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (hostManager, error) {
return &noopHostConfigurator{}, nil
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type ServiceViaMemory struct {

View File

@@ -0,0 +1,19 @@
package dns
import (
"context"
)
type ShutdownState struct{}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
return nil
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type upstreamResolver struct {

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -11,14 +12,18 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
listenPort uint16 = 5353
listenPortMu sync.RWMutex
)
const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
ListenPort = 5353
dnsTTL = 60 //seconds
dnsTTL = 60 //seconds
)
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
@@ -35,12 +40,20 @@ type Manager struct {
fwRules []firewall.Rule
tcpRules []firewall.Rule
dnsForwarder *DNSForwarder
port uint16
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
func ListenPort() uint16 {
listenPortMu.RLock()
defer listenPortMu.RUnlock()
return listenPort
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
return &Manager{
firewall: fw,
statusRecorder: statusRecorder,
port: port,
}
}
@@ -54,7 +67,13 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
if m.port > 0 {
listenPortMu.Lock()
listenPort = m.port
listenPortMu.Unlock()
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists
@@ -94,7 +113,7 @@ func (m *Manager) Stop(ctx context.Context) error {
func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []uint16{ListenPort},
Values: []uint16{ListenPort()},
}
if m.firewall == nil {

View File

@@ -29,9 +29,9 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -168,7 +168,7 @@ type Engine struct {
wgInterface WGIface
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
@@ -203,6 +203,13 @@ type Engine struct {
// auto-update
updateManager *updatemanager.UpdateManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
// dns forwarder port
dnsFwdPort uint16
}
// Peer is an instance of the Connection Peer
@@ -245,6 +252,7 @@ func NewEngine(
statusRecorder: statusRecorder,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
dnsFwdPort: dnsfwd.ListenPort(),
}
sm := profilemanager.NewServiceManager("")
@@ -350,6 +358,9 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
return nil
}
@@ -466,14 +477,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("initialize dns server: %w", err)
}
iceCfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
iceCfg := e.createICEConfig()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.connMgr.Start(e.ctx)
@@ -486,6 +490,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
}()
return nil
}
@@ -977,7 +997,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.LazyConnectionEnabled,
)
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
@@ -988,7 +1007,7 @@ func (e *Engine) receiveManagementEvents() {
}
log.Debugf("stopped receiving updates from Management Service")
}()
log.Debugf("connecting to Management Service updates stream")
log.Infof("connecting to Management Service updates stream")
}
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
@@ -1093,7 +1112,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
// Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
@@ -1351,14 +1370,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
Addr: e.getRosenpassAddr(),
PermissiveMode: e.config.RosenpassPermissive,
},
ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},
ICEConfig: e.createICEConfig(),
}
serviceDependencies := peer.ServiceDependencies{
@@ -1859,6 +1871,7 @@ func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
forwarderPort uint16,
) {
if e.config.DisableServerRoutes {
return
@@ -1875,16 +1888,20 @@ func (e *Engine) updateDNSForwarder(
}
if len(fwdEntries) > 0 {
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
switch {
case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
log.Infof("started domain router service with %d entries", len(fwdEntries))
} else {
case e.dnsFwdPort != forwarderPort:
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
e.restartDnsFwd(fwdEntries, forwarderPort)
e.dnsFwdPort = forwarderPort
default:
e.dnsForwardMgr.UpdateDomains(fwdEntries)
}
} else if e.dnsForwardMgr != nil {
@@ -1894,6 +1911,20 @@ func (e *Engine) updateDNSForwarder(
}
e.dnsForwardMgr = nil
}
}
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
// stop and start the forwarder to apply the new port
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
}
func (e *Engine) GetNet() (*netstack.Net, error) {

View File

@@ -0,0 +1,19 @@
//go:build !js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for non-WASM environments
func (e *Engine) createICEConfig() icemaker.Config {
return icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
}

View File

@@ -0,0 +1,18 @@
//go:build js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for WASM environment.
func (e *Engine) createICEConfig() icemaker.Config {
cfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
return cfg
}

View File

@@ -26,10 +26,15 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -41,10 +46,8 @@ import (
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -84,7 +87,7 @@ type MockWGIface struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
@@ -134,7 +137,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
@@ -413,7 +416,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
@@ -1583,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err
}

View File

@@ -9,9 +9,9 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
Name() string
Address() wgaddr.Address
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error

View File

@@ -14,7 +14,7 @@ import (
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const defaultChannelSize = 100

View File

@@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
return false
}

View File

@@ -0,0 +1,12 @@
package networkmonitor
import (
"context"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
// No-op for WASM - network changes don't apply
return nil
}

View File

@@ -6,7 +6,6 @@ import (
"math/rand"
"net"
"net/netip"
"os"
"runtime"
"sync"
"time"
@@ -29,10 +28,6 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
const (
defaultWgKeepAlive = 25 * time.Second
)
type ServiceDependencies struct {
StatusRecorder *Status
Signaler *Signaler
@@ -118,6 +113,8 @@ type Conn struct {
// debug purpose
dumpState *stateDump
endpointUpdater *EndpointUpdater
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -130,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
}
return conn, nil
@@ -174,7 +172,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if os.Getenv("NB_FORCE_RELAY") != "true" {
if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
@@ -250,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgProxyICE = nil
}
if err := conn.removeWgPeer(); err != nil {
if err := conn.endpointUpdater.RemoveWgPeer(); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
@@ -376,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
wgProxy.Work()
}
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
conn.Log.Debugf("redirect packets from relayed conn to WireGuard")
conn.wgProxyRelay.RedirectAs(ep)
}
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
@@ -410,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() {
conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
@@ -419,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
@@ -478,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
}
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
}
@@ -546,17 +554,6 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
presharedKey := conn.presharedKey(remoteRPKey)
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
conn.config.WgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
presharedKey,
)
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
@@ -699,10 +696,6 @@ func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
}
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {

View File

@@ -0,0 +1,105 @@
package peer
import (
"context"
"net"
"sync"
"time"
"github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
const (
defaultWgKeepAlive = 25 * time.Second
fallbackDelay = 5 * time.Second
)
type EndpointUpdater struct {
log *logrus.Entry
wgConfig WgConfig
initiator bool
// mu protects updateWireGuardPeer and cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
}
func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater {
return &EndpointUpdater{
log: log,
wgConfig: wgConfig,
initiator: initiator,
}
}
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock()
defer e.mu.Unlock()
if e.initiator {
e.log.Debugf("configure up WireGuard as initiatr")
return e.updateWireGuardPeer(addr, presharedKey)
}
// prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate()
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
e.log.Debugf("configure up WireGuard and wait for handshake")
return e.updateWireGuardPeer(nil, presharedKey)
}
func (e *EndpointUpdater) RemoveWgPeer() error {
e.mu.Lock()
defer e.mu.Unlock()
e.waitForCloseTheDelayedUpdate()
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
if e.cancelFunc == nil {
return
}
e.cancelFunc()
e.cancelFunc = nil
e.updateWg.Wait()
}
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) {
defer e.updateWg.Done()
t := time.NewTimer(fallbackDelay)
defer t.Stop()
select {
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}
func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error {
return e.wgConfig.WgInterface.UpdatePeer(
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
endpoint,
presharedKey,
)
}

View File

@@ -0,0 +1,14 @@
package peer
import (
"os"
"strings"
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func isForceRelayed() bool {
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}

View File

@@ -3,6 +3,8 @@ package guard
import (
"context"
"fmt"
"slices"
"sort"
"sync"
"time"
@@ -24,8 +26,8 @@ type ICEMonitor struct {
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config
currentCandidates []ice.Candidate
candidatesMu sync.Mutex
currentCandidatesAddress []string
candidatesMu sync.Mutex
}
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
@@ -115,16 +117,21 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
cm.candidatesMu.Lock()
defer cm.candidatesMu.Unlock()
if len(cm.currentCandidates) != len(newCandidates) {
cm.currentCandidates = newCandidates
newAddresses := make([]string, len(newCandidates))
for i, c := range newCandidates {
newAddresses[i] = c.Address()
}
sort.Strings(newAddresses)
if len(cm.currentCandidatesAddress) != len(newAddresses) {
cm.currentCandidatesAddress = newAddresses
return true
}
for i, candidate := range cm.currentCandidates {
if candidate.Address() != newCandidates[i].Address() {
cm.currentCandidates = newCandidates
return true
}
// Compare elements
if !slices.Equal(cm.currentCandidatesAddress, newAddresses) {
cm.currentCandidatesAddress = newAddresses
return true
}
return false

View File

@@ -30,9 +30,10 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabledTime time.Time
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
w.enabledTime = time.Now()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
@@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
onDisconnectedFn()
return
}
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
lastHandshake = *handshake
resetTime := time.Until(handshake.Add(checkPeriod))

View File

@@ -9,11 +9,10 @@ import (
"time"
"github.com/pion/ice/v4"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -55,10 +54,6 @@ type WorkerICE struct {
sessionID ICESessionID
muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string
localPwd string
@@ -139,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.muxAgent.Unlock()
return
}
w.sentExtraSrflx = false
w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
@@ -166,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
w.log.Errorf("error while handling remote candidate")
return
}
if shouldAddExtraCandidate(candidate) {
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
@@ -209,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
return nil, err
}
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
w.onICESelectedCandidatePair(agent, c1, c2)
}); err != nil {
return nil, err
}
@@ -327,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
return
}
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
if !ok {
w.log.Warn("invalid udp mux conversion")
return
@@ -354,48 +365,19 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
w.muxAgent.Lock()
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
w.log.Warnf("failed to get selected candidate pair: %s", err)
w.muxAgent.Unlock()
pairStat, ok := agent.GetSelectedCandidatePairStats()
if !ok {
w.log.Warnf("failed to get selected candidate pair stats")
return
}
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.muxAgent.Unlock()
return
}
w.muxAgent.Unlock()
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
@@ -424,22 +406,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key
if isControlling {
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
if isController(w.config) {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
if candidate.Type() != ice.CandidateTypeServerReflexive {
return false
}
if candidate.Port() == candidate.RelatedAddress().Port {
return false
}
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
// in newer version we generate locally the extra candidate
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
return false
}
return true
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
@@ -455,6 +446,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
}
for _, e := range candidate.Extensions() {
// overwrite the original candidate ID with the new one to avoid candidate duplication
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = candidate.ID()
}
if err := ec.AddExtension(e); err != nil {
return nil, err
}

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// ProbeResult holds the info about the result of a relay probe request

View File

@@ -24,8 +24,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
const dnsTimeout = 8 * time.Second
@@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()

View File

@@ -36,9 +36,9 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
}
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
return nil
}
if err := m.sysOps.CleanupRouting(nil); err != nil {
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
ips := resolveURLsToIPs(initialAddresses)
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
return fmt.Errorf("setup routing: %w", err)
}
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
if runtime.GOOS == "windows" {
nbnet.SetVPNInterfaceName("")
}
}
m.mux.Lock()

Some files were not shown because too many files have changed in this diff Show More