mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
Compare commits
26 Commits
v0.57.0
...
cli-ws-pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdf4f10d94 | ||
|
|
2b8a9f55c1 | ||
|
|
e7b5537dcc | ||
|
|
95794f53ce | ||
|
|
9bcd3ebed4 | ||
|
|
b85045e723 | ||
|
|
4d7e59f199 | ||
|
|
b5daec3b51 | ||
|
|
5e1a40c33f | ||
|
|
e8d301fdc9 | ||
|
|
17bab881f7 | ||
|
|
25ed58328a | ||
|
|
644ed4b934 | ||
|
|
58faa341d2 | ||
|
|
5853b5553c | ||
|
|
998fb30e1e | ||
|
|
e254b4cde5 | ||
|
|
ead1c618ba | ||
|
|
55126f990c | ||
|
|
90577682e4 | ||
|
|
dc30dcacce | ||
|
|
2c87fa6236 | ||
|
|
ec8d83ade4 | ||
|
|
3130cce72d | ||
|
|
bd23ab925e | ||
|
|
0c6f671a7c |
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -217,7 +217,7 @@ jobs:
|
|||||||
- arch: "386"
|
- arch: "386"
|
||||||
raceFlag: ""
|
raceFlag: ""
|
||||||
- arch: "amd64"
|
- arch: "amd64"
|
||||||
raceFlag: ""
|
raceFlag: "-race"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.22"
|
SIGN_PIPE_VER: "v0.0.23"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
67
.github/workflows/wasm-build-validation.yml
vendored
Normal file
67
.github/workflows/wasm-build-validation.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
name: Wasm
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
js_lint:
|
||||||
|
name: "JS / Lint"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
|
- name: Install golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
|
||||||
|
with:
|
||||||
|
version: latest
|
||||||
|
install-mode: binary
|
||||||
|
skip-cache: true
|
||||||
|
skip-pkg-cache: true
|
||||||
|
skip-build-cache: true
|
||||||
|
- name: Run golangci-lint for WASM
|
||||||
|
run: |
|
||||||
|
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
js_build:
|
||||||
|
name: "JS / Build"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
- name: Build Wasm client
|
||||||
|
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||||
|
env:
|
||||||
|
CGO_ENABLED: 0
|
||||||
|
- name: Check Wasm build size
|
||||||
|
run: |
|
||||||
|
echo "Wasm build size:"
|
||||||
|
ls -lh netbird.wasm
|
||||||
|
|
||||||
|
SIZE=$(stat -c%s netbird.wasm)
|
||||||
|
SIZE_MB=$((SIZE / 1024 / 1024))
|
||||||
|
|
||||||
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
|
if [ ${SIZE} -gt 52428800 ]; then
|
||||||
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
0
.gitmodules
vendored
Normal file
0
.gitmodules
vendored
Normal file
@@ -2,6 +2,18 @@ version: 2
|
|||||||
|
|
||||||
project_name: netbird
|
project_name: netbird
|
||||||
builds:
|
builds:
|
||||||
|
- id: netbird-wasm
|
||||||
|
dir: client/wasm/cmd
|
||||||
|
binary: netbird
|
||||||
|
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
|
||||||
|
goos:
|
||||||
|
- js
|
||||||
|
goarch:
|
||||||
|
- wasm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird
|
- id: netbird
|
||||||
dir: client
|
dir: client
|
||||||
binary: netbird
|
binary: netbird
|
||||||
@@ -115,6 +127,11 @@ archives:
|
|||||||
- builds:
|
- builds:
|
||||||
- netbird
|
- netbird
|
||||||
- netbird-static
|
- netbird-static
|
||||||
|
- id: netbird-wasm
|
||||||
|
builds:
|
||||||
|
- netbird-wasm
|
||||||
|
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||||
|
format: binary
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<br/>
|
<br/>
|
||||||
<br/>
|
<br/>
|
||||||
@@ -52,7 +53,7 @@
|
|||||||
|
|
||||||
### Open Source Network Security in a Single Platform
|
### Open Source Network Security in a Single Platform
|
||||||
|
|
||||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### NetBird on Lawrence Systems (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ ENV \
|
|||||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/util/net"
|
"github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ type ErrListener interface {
|
|||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(string)
|
||||||
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth can register or login new client
|
// Auth can register or login new client
|
||||||
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
go urlOpener.OnLoginSuccess()
|
||||||
|
}
|
||||||
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
8
client/cmd/debug_js.go
Normal file
8
client/cmd/debug_js.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// SetupDebugHandler is a no-op for WASM
|
||||||
|
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
|
||||||
|
// Debug handler not needed for WASM
|
||||||
|
}
|
||||||
@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
|||||||
|
|
||||||
// DialClientGRPCServer returns client connection to the daemon server.
|
// DialClientGRPCServer returns client connection to the daemon server.
|
||||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
return grpc.DialContext(
|
return grpc.DialContext(
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
status, err := client.Status(ctx, &proto.StatusRequest{
|
||||||
|
WaitForReady: func() *bool { b := true; return &b }(),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,23 +23,29 @@ import (
|
|||||||
|
|
||||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
var ErrClientNotStarted = errors.New("client not started")
|
var ErrClientNotStarted = errors.New("client not started")
|
||||||
|
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
setupKey string
|
setupKey string
|
||||||
|
jwtToken string
|
||||||
connect *internal.ConnectClient
|
connect *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options configures a new Client
|
// Options configures a new Client.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
// DeviceName is this peer's name in the network
|
// DeviceName is this peer's name in the network
|
||||||
DeviceName string
|
DeviceName string
|
||||||
// SetupKey is used for authentication
|
// SetupKey is used for authentication
|
||||||
SetupKey string
|
SetupKey string
|
||||||
|
// JWTToken is used for JWT-based authentication
|
||||||
|
JWTToken string
|
||||||
|
// PrivateKey is used for direct private key authentication
|
||||||
|
PrivateKey string
|
||||||
// ManagementURL overrides the default management server URL
|
// ManagementURL overrides the default management server URL
|
||||||
ManagementURL string
|
ManagementURL string
|
||||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||||
@@ -58,8 +64,35 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new netbird embedded client
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
|
func (opts *Options) validateCredentials() error {
|
||||||
|
credentialsProvided := 0
|
||||||
|
if opts.SetupKey != "" {
|
||||||
|
credentialsProvided++
|
||||||
|
}
|
||||||
|
if opts.JWTToken != "" {
|
||||||
|
credentialsProvided++
|
||||||
|
}
|
||||||
|
if opts.PrivateKey != "" {
|
||||||
|
credentialsProvided++
|
||||||
|
}
|
||||||
|
|
||||||
|
if credentialsProvided == 0 {
|
||||||
|
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
|
||||||
|
}
|
||||||
|
if credentialsProvided > 1 {
|
||||||
|
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new netbird embedded client.
|
||||||
func New(opts Options) (*Client, error) {
|
func New(opts Options) (*Client, error) {
|
||||||
|
if err := opts.validateCredentials(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if opts.LogOutput != nil {
|
if opts.LogOutput != nil {
|
||||||
logrus.SetOutput(opts.LogOutput)
|
logrus.SetOutput(opts.LogOutput)
|
||||||
}
|
}
|
||||||
@@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) {
|
|||||||
return nil, fmt.Errorf("create config: %w", err)
|
return nil, fmt.Errorf("create config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.PrivateKey != "" {
|
||||||
|
config.PrivateKey = opts.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
deviceName: opts.DeviceName,
|
deviceName: opts.DeviceName,
|
||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
ctx := internal.CtxInitState(context.Background())
|
ctx := internal.CtxInitState(context.Background())
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
run := make(chan struct{}, 1)
|
run := make(chan struct{})
|
||||||
clientErr := make(chan error, 1)
|
clientErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := client.Run(run); err != nil {
|
if err := client.Run(run); err != nil {
|
||||||
@@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfig returns a copy of the internal client config.
|
||||||
|
func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.config == nil {
|
||||||
|
return profilemanager.Config{}, ErrConfigNotInitialized
|
||||||
|
}
|
||||||
|
return *c.config, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Dial dials a network address in the netbird network.
|
// Dial dials a network address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
@@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
|||||||
return nsnet.DialContext(ctx, network, address)
|
return nsnet.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenTCP listens on the given address in the netbird network
|
// ListenTCP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
nsnet, addr, err := c.getNet()
|
nsnet, addr, err := c.getNet()
|
||||||
@@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
|||||||
return nsnet.ListenTCP(tcpAddr)
|
return nsnet.ListenTCP(tcpAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenUDP listens on the given address in the netbird network
|
// ListenUDP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||||
nsnet, addr, err := c.getNet()
|
nsnet, addr, err := c.getNet()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
func isIptablesSupported() bool {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -4,15 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os/user"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -21,35 +15,9 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
currentUser, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
|
||||||
if currentUser.Uid != "0" {
|
|
||||||
log.Debug("Not running as root, using standard dialer")
|
|
||||||
dialer := &net.Dialer{}
|
|
||||||
return dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to dial: %s", err)
|
|
||||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// grpcDialBackoff is the backoff mechanism for the grpc calls
|
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
b.MaxElapsedTime = 10 * time.Second
|
b.MaxElapsedTime = 10 * time.Second
|
||||||
@@ -57,7 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
if tlsEnabled {
|
if tlsEnabled {
|
||||||
certPool, err := x509.SystemCertPool()
|
certPool, err := x509.SystemCertPool()
|
||||||
@@ -67,18 +37,20 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||||
|
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
|
||||||
|
InsecureSkipVerify: runtime.GOOS == "js",
|
||||||
RootCAs: certPool,
|
RootCAs: certPool,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
conn, err := grpc.DialContext(
|
||||||
connCtx,
|
connCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
44
client/grpc/dialer_generic.go
Normal file
44
client/grpc/dialer_generic.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||||
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||||
|
if currentUser.Uid != "0" {
|
||||||
|
log.Debug("Not running as root, using standard dialer")
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
return dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to dial: %s", err)
|
||||||
|
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
13
client/grpc/dialer_js.go
Normal file
13
client/grpc/dialer_js.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util/wsproxy/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
|
||||||
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
|
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||||
|
return client.WithWebSocketDialer(tlsEnabled, component)
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@ package bind
|
|||||||
import (
|
import (
|
||||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||||
|
|||||||
@@ -1,5 +1,17 @@
|
|||||||
package bind
|
package bind
|
||||||
|
|
||||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
type Endpoint = wgConn.StdNetEndpoint
|
type Endpoint = wgConn.StdNetEndpoint
|
||||||
|
|
||||||
|
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: e.Addr().AsSlice(),
|
||||||
|
Port: int(e.Port()),
|
||||||
|
Zone: e.Addr().Zone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
7
client/iface/bind/error.go
Normal file
7
client/iface/bind/error.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
|
||||||
|
)
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package bind
|
package bind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -17,14 +20,9 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
|
||||||
Endpoint *Endpoint
|
|
||||||
Buffer []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type receiverCreator struct {
|
type receiverCreator struct {
|
||||||
iceBind *ICEBind
|
iceBind *ICEBind
|
||||||
}
|
}
|
||||||
@@ -42,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
|
|||||||
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||||
type ICEBind struct {
|
type ICEBind struct {
|
||||||
*wgConn.StdNetBind
|
*wgConn.StdNetBind
|
||||||
RecvChan chan RecvMessage
|
|
||||||
|
|
||||||
transportNet transport.Net
|
transportNet transport.Net
|
||||||
filterFn udpmux.FilterFn
|
filterFn udpmux.FilterFn
|
||||||
|
address wgaddr.Address
|
||||||
|
mtu uint16
|
||||||
|
|
||||||
endpoints map[netip.Addr]net.Conn
|
endpoints map[netip.Addr]net.Conn
|
||||||
endpointsMu sync.Mutex
|
endpointsMu sync.Mutex
|
||||||
|
recvChan chan recvMessage
|
||||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||||
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
||||||
closedChan chan struct{}
|
closedChan chan struct{}
|
||||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||||
closed bool
|
closed bool
|
||||||
|
activityRecorder *ActivityRecorder
|
||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
address wgaddr.Address
|
|
||||||
mtu uint16
|
|
||||||
activityRecorder *ActivityRecorder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
RecvChan: make(chan RecvMessage, 1),
|
|
||||||
transportNet: transportNet,
|
transportNet: transportNet,
|
||||||
filterFn: filterFn,
|
filterFn: filterFn,
|
||||||
|
address: address,
|
||||||
|
mtu: mtu,
|
||||||
endpoints: make(map[netip.Addr]net.Conn),
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
|
recvChan: make(chan recvMessage, 1),
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
mtu: mtu,
|
|
||||||
address: address,
|
|
||||||
activityRecorder: NewActivityRecorder(),
|
activityRecorder: NewActivityRecorder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg
|
|||||||
return ib
|
return ib
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ICEBind) MTU() uint16 {
|
|
||||||
return s.mtu
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||||
s.closed = false
|
s.closed = false
|
||||||
s.closedChanMu.Lock()
|
s.closedChanMu.Lock()
|
||||||
@@ -139,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
|||||||
delete(b.endpoints, fakeIP)
|
delete(b.endpoints, fakeIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||||
|
select {
|
||||||
|
case <-b.closedChan:
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case b.recvChan <- recvMessage{ep, buf}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
b.endpointsMu.Lock()
|
b.endpointsMu.Lock()
|
||||||
conn, ok := b.endpoints[ep.DstIP()]
|
conn, ok := b.endpoints[ep.DstIP()]
|
||||||
@@ -271,7 +276,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
|||||||
select {
|
select {
|
||||||
case <-c.closedChan:
|
case <-c.closedChan:
|
||||||
return 0, net.ErrClosed
|
return 0, net.ErrClosed
|
||||||
case msg, ok := <-c.RecvChan:
|
case msg, ok := <-c.recvChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, net.ErrClosed
|
return 0, net.ErrClosed
|
||||||
}
|
}
|
||||||
|
|||||||
6
client/iface/bind/recv_msg.go
Normal file
6
client/iface/bind/recv_msg.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
type recvMessage struct {
|
||||||
|
Endpoint *Endpoint
|
||||||
|
Buffer []byte
|
||||||
|
}
|
||||||
125
client/iface/bind/relay_bind.go
Normal file
125
client/iface/bind/relay_bind.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
|
||||||
|
// Do not limit to build only js, because we want to be able to run tests
|
||||||
|
type RelayBindJS struct {
|
||||||
|
*conn.StdNetBind
|
||||||
|
|
||||||
|
recvChan chan recvMessage
|
||||||
|
endpoints map[netip.Addr]net.Conn
|
||||||
|
endpointsMu sync.Mutex
|
||||||
|
activityRecorder *ActivityRecorder
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRelayBindJS() *RelayBindJS {
|
||||||
|
return &RelayBindJS{
|
||||||
|
recvChan: make(chan recvMessage, 100),
|
||||||
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
|
activityRecorder: NewActivityRecorder(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open creates a receive function for handling relay packets in WASM.
|
||||||
|
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||||
|
log.Debugf("Open: creating receive function for port %d", uport)
|
||||||
|
|
||||||
|
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
case msg, ok := <-s.recvChan:
|
||||||
|
if !ok {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
copy(bufs[0], msg.Buffer)
|
||||||
|
sizes[0] = len(msg.Buffer)
|
||||||
|
eps[0] = conn.Endpoint(msg.Endpoint)
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Open: receive function created, returning port %d", uport)
|
||||||
|
return []conn.ReceiveFunc{receiveFn}, uport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) Close() error {
|
||||||
|
if s.cancel == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debugf("close RelayBindJS")
|
||||||
|
s.cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case s.recvChan <- recvMessage{ep, buf}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send forwards packets through the relay connection for WASM.
|
||||||
|
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||||
|
if ep == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeIP := ep.DstIP()
|
||||||
|
|
||||||
|
s.endpointsMu.Lock()
|
||||||
|
relayConn, ok := s.endpoints[fakeIP]
|
||||||
|
s.endpointsMu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, buf := range bufs {
|
||||||
|
if _, err := relayConn.Write(buf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||||
|
b.endpointsMu.Lock()
|
||||||
|
b.endpoints[fakeIP] = conn
|
||||||
|
b.endpointsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
|
||||||
|
s.endpointsMu.Lock()
|
||||||
|
defer s.endpointsMu.Unlock()
|
||||||
|
|
||||||
|
delete(s.endpoints, fakeIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
|
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
|
return nil, ErrUDPMUXNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
|
||||||
|
return s.activityRecorder
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux || windows || freebsd
|
//go:build linux || windows || freebsd || js || wasip1
|
||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !windows
|
//go:build !windows && !js
|
||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
|
|||||||
23
client/iface/configurer/uapi_js.go
Normal file
23
client/iface/configurer/uapi_js.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type noopListener struct{}
|
||||||
|
|
||||||
|
func (n *noopListener) Accept() (net.Conn, error) {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopListener) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopListener) Addr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openUAPI(deviceName string) (net.Listener, error) {
|
||||||
|
return &noopListener{}, nil
|
||||||
|
}
|
||||||
@@ -17,8 +17,8 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -409,7 +409,7 @@ func toBytes(s string) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
|
||||||
return nbnet.ControlPlaneMark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/sharedsock"
|
"github.com/netbirdio/netbird/sharedsock"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunKernelDevice struct {
|
type TunKernelDevice struct {
|
||||||
@@ -101,13 +101,8 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var udpConn net.PacketConn = rawSock
|
|
||||||
if !nbnet.AdvancedRouting() {
|
|
||||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
|
||||||
}
|
|
||||||
|
|
||||||
bindParams := udpmux.UniversalUDPMuxParams{
|
bindParams := udpmux.UniversalUDPMuxParams{
|
||||||
UDPConn: udpConn,
|
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||||
Net: t.transportNet,
|
Net: t.transportNet,
|
||||||
FilterFn: t.filterFn,
|
FilterFn: t.filterFn,
|
||||||
WGAddress: t.address,
|
WGAddress: t.address,
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
@@ -12,9 +14,15 @@ import (
|
|||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Bind interface {
|
||||||
|
conn.Bind
|
||||||
|
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
|
ActivityRecorder() *bind.ActivityRecorder
|
||||||
|
}
|
||||||
|
|
||||||
type TunNetstackDevice struct {
|
type TunNetstackDevice struct {
|
||||||
name string
|
name string
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
@@ -22,7 +30,7 @@ type TunNetstackDevice struct {
|
|||||||
key string
|
key string
|
||||||
mtu uint16
|
mtu uint16
|
||||||
listenAddress string
|
listenAddress string
|
||||||
iceBind *bind.ICEBind
|
bind Bind
|
||||||
|
|
||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
@@ -33,7 +41,7 @@ type TunNetstackDevice struct {
|
|||||||
net *netstack.Net
|
net *netstack.Net
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
|
||||||
return &TunNetstackDevice{
|
return &TunNetstackDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
|
|||||||
key: key,
|
key: key,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
iceBind: iceBind,
|
bind: bind,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
|
|
||||||
t.device = device.NewDevice(
|
t.device = device.NewDevice(
|
||||||
t.filteredDevice,
|
t.filteredDevice,
|
||||||
t.iceBind,
|
t.bind,
|
||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
_ = tunIface.Close()
|
||||||
@@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
udpMux, err := t.iceBind.GetICEMux()
|
udpMux, err := t.bind.GetICEMux()
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if udpMux != nil {
|
||||||
t.udpMux = udpMux
|
t.udpMux = udpMux
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("netstack device is ready to use")
|
log.Debugf("netstack device is ready to use")
|
||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|||||||
27
client/iface/device/device_netstack_test.go
Normal file
27
client/iface/device/device_netstack_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewNetstackDevice(t *testing.T) {
|
||||||
|
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
|
||||||
|
|
||||||
|
relayBind := bind.NewRelayBindJS()
|
||||||
|
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
|
||||||
|
|
||||||
|
cfgr, err := nsTun.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create netstack device: %v", err)
|
||||||
|
}
|
||||||
|
if cfgr == nil {
|
||||||
|
t.Fatal("expected non-nil configurer")
|
||||||
|
}
|
||||||
|
}
|
||||||
6
client/iface/iface_destroy_js.go
Normal file
6
client/iface/iface_destroy_js.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
// Destroy is a no-op on WASM
|
||||||
|
func (w *WGIface) Destroy() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: tun,
|
tun: tun,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|||||||
41
client/iface/iface_new_freebsd.go
Normal file
41
client/iface/iface_new_freebsd.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build freebsd
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wgIFace := &WGIface{}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if device.ModuleTunIsLoaded() {
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
|
wgIFace.userspaceBind = true
|
||||||
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
|
return wgIFace, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||||
|
}
|
||||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|||||||
27
client/iface/iface_new_js.go
Normal file
27
client/iface/iface_new_js.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
relayBind := bind.NewRelayBindJS()
|
||||||
|
|
||||||
|
wgIface := &WGIface{
|
||||||
|
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||||
|
userspaceBind: true,
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build (linux && !android) || freebsd
|
//go:build linux && !android
|
||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: tun,
|
tun: tun,
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
12
client/iface/netstack/env_js.go
Normal file
12
client/iface/netstack/env_js.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package netstack
|
||||||
|
|
||||||
|
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||||
|
|
||||||
|
// IsEnabled always returns true for js since it's the only mode available
|
||||||
|
func IsEnabled() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenAddr() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
package udpmux
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||||
|
|||||||
@@ -16,28 +16,38 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProxyBind struct {
|
type Bind interface {
|
||||||
Bind *bind.ICEBind
|
SetEndpoint(addr netip.Addr, conn net.Conn)
|
||||||
|
RemoveEndpoint(addr netip.Addr)
|
||||||
|
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
|
||||||
|
}
|
||||||
|
|
||||||
fakeNetIP *netip.AddrPort
|
type ProxyBind struct {
|
||||||
wgBindEndpoint *bind.Endpoint
|
bind Bind
|
||||||
|
|
||||||
|
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
|
||||||
|
wgRelayedEndpoint *bind.Endpoint
|
||||||
|
wgCurrentUsed *bind.Endpoint
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
closeMu sync.Mutex
|
closeMu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
pausedMu sync.Mutex
|
|
||||||
paused bool
|
paused bool
|
||||||
|
pausedCond *sync.Cond
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
closeListener *listener.CloseListener
|
||||||
|
mtu uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
|
||||||
p := &ProxyBind{
|
p := &ProxyBind{
|
||||||
Bind: bind,
|
bind: bind,
|
||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
|
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
mtu: mtu + bufsize.WGBufferOverhead,
|
||||||
}
|
}
|
||||||
|
|
||||||
return p
|
return p
|
||||||
@@ -46,25 +56,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
|||||||
// AddTurnConn adds a new connection to the bind.
|
// AddTurnConn adds a new connection to the bind.
|
||||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||||
// WireGuard configuration.
|
// WireGuard configuration.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
|
||||||
|
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
|
||||||
|
// - remoteConn: The established TURN connection to the remote peer
|
||||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
fakeNetIP, err := fakeAddress(nbAddr)
|
fakeNetIP, err := fakeAddress(nbAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||||
p.fakeNetIP = fakeNetIP
|
|
||||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||||
return &net.UDPAddr{
|
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
|
||||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
|
||||||
Port: int(p.fakeNetIP.Port()),
|
|
||||||
Zone: p.fakeNetIP.Addr().Zone(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||||
@@ -76,17 +86,21 @@ func (p *ProxyBind) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
p.pausedMu.Unlock()
|
|
||||||
|
p.wgCurrentUsed = p.wgRelayedEndpoint
|
||||||
|
|
||||||
// Start the proxy only once
|
// Start the proxy only once
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
go p.proxyToLocal(p.ctx)
|
go p.proxyToLocal(p.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) Pause() {
|
func (p *ProxyBind) Pause() {
|
||||||
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = true
|
p.paused = true
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
|
||||||
|
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||||
|
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||||
|
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||||
|
return &bind.Endpoint{AddrPort: addrPort}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) CloseConn() error {
|
func (p *ProxyBind) CloseConn() error {
|
||||||
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) close() error {
|
func (p *ProxyBind) close() error {
|
||||||
|
if p.remoteConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
p.closeMu.Lock()
|
p.closeMu.Lock()
|
||||||
defer p.closeMu.Unlock()
|
defer p.closeMu.Unlock()
|
||||||
|
|
||||||
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
|
|||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
|
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
|
||||||
|
|
||||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||||
return rErr
|
return rErr
|
||||||
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
|
buf := make([]byte, p.mtu)
|
||||||
n, err := p.remoteConn.Read(buf)
|
n, err := p.remoteConn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -147,18 +186,13 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
if p.paused {
|
for p.paused {
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.Wait()
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := bind.RecvMessage{
|
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
|
||||||
Endpoint: p.wgBindEndpoint,
|
p.pausedCond.L.Unlock()
|
||||||
Buffer: buf[:n],
|
|
||||||
}
|
|
||||||
p.Bind.RecvChan <- msg
|
|
||||||
p.pausedMu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
@@ -18,15 +16,20 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
loopbackAddr = "127.0.0.1"
|
loopbackAddr = "127.0.0.1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||||
|
)
|
||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
type WGEBPFProxy struct {
|
type WGEBPFProxy struct {
|
||||||
localWGListenPort int
|
localWGListenPort int
|
||||||
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.rawConn, err = p.prepareSenderRawSocket()
|
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -214,57 +217,17 @@ generatePort:
|
|||||||
return p.lastUsedPort, nil
|
return p.lastUsedPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||||
// Create a raw socket.
|
|
||||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
|
||||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind the socket to the "lo" interface.
|
|
||||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the fwmark on the socket.
|
|
||||||
err = nbnet.SetSocketOpt(fd)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert the file descriptor to a PacketConn.
|
|
||||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
|
||||||
if file == nil {
|
|
||||||
return nil, fmt.Errorf("converting fd to file failed")
|
|
||||||
}
|
|
||||||
packetConn, err := net.FilePacketConn(file)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return packetConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
|
||||||
localhost := net.ParseIP("127.0.0.1")
|
|
||||||
|
|
||||||
payload := gopacket.Payload(data)
|
payload := gopacket.Payload(data)
|
||||||
ipH := &layers.IPv4{
|
ipH := &layers.IPv4{
|
||||||
DstIP: localhost,
|
DstIP: localHostNetIP,
|
||||||
SrcIP: localhost,
|
SrcIP: endpointAddr.IP,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
udpH := &layers.UDP{
|
udpH := &layers.UDP{
|
||||||
SrcPort: layers.UDPPort(port),
|
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("serialize layers: %w", err)
|
return fmt.Errorf("serialize layers: %w", err)
|
||||||
}
|
}
|
||||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||||
return fmt.Errorf("write to raw conn: %w", err)
|
return fmt.Errorf("write to raw conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -18,41 +18,42 @@ import (
|
|||||||
|
|
||||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||||
type ProxyWrapper struct {
|
type ProxyWrapper struct {
|
||||||
WgeBPFProxy *WGEBPFProxy
|
wgeBPFProxy *WGEBPFProxy
|
||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
|
||||||
wgEndpointAddr *net.UDPAddr
|
wgRelayedEndpointAddr *net.UDPAddr
|
||||||
|
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||||
|
|
||||||
pausedMu sync.Mutex
|
|
||||||
paused bool
|
paused bool
|
||||||
|
pausedCond *sync.Cond
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
closeListener *listener.CloseListener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||||
return &ProxyWrapper{
|
return &ProxyWrapper{
|
||||||
WgeBPFProxy: WgeBPFProxy,
|
wgeBPFProxy: proxy,
|
||||||
|
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add turn conn: %w", err)
|
return fmt.Errorf("add turn conn: %w", err)
|
||||||
}
|
}
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
p.wgEndpointAddr = addr
|
p.wgRelayedEndpointAddr = addr
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||||
return p.wgEndpointAddr
|
return p.wgRelayedEndpointAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||||
@@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
p.pausedMu.Unlock()
|
|
||||||
|
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||||
|
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
go p.proxyToLocal(p.ctx)
|
go p.proxyToLocal(p.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) Pause() {
|
func (p *ProxyWrapper) Pause() {
|
||||||
@@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = true
|
p.paused = true
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
|
||||||
|
p.wgEndpointCurrentUsedAddr = endpoint
|
||||||
|
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||||
func (e *ProxyWrapper) CloseConn() error {
|
func (p *ProxyWrapper) CloseConn() error {
|
||||||
if e.cancel == nil {
|
if p.cancel == nil {
|
||||||
return fmt.Errorf("proxy not started")
|
return fmt.Errorf("proxy not started")
|
||||||
}
|
}
|
||||||
|
|
||||||
e.cancel()
|
p.cancel()
|
||||||
|
|
||||||
e.closeListener.SetCloseListener(nil)
|
p.closeListener.SetCloseListener(nil)
|
||||||
|
|
||||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
p.pausedCond.L.Lock()
|
||||||
return fmt.Errorf("close remote conn: %w", err)
|
p.paused = false
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
|
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
|
||||||
|
|
||||||
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||||
for {
|
for {
|
||||||
n, err := p.readFromRemote(ctx, buf)
|
n, err := p.readFromRemote(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
if p.paused {
|
for p.paused {
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.Wait()
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
|||||||
}
|
}
|
||||||
p.closeListener.Notify()
|
p.closeListener.Notify()
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
func (w *KernelFactory) Free() error {
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
package wgproxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// KernelFactory todo: check eBPF support on FreeBSD
|
|
||||||
type KernelFactory struct {
|
|
||||||
wgPort int
|
|
||||||
mtu uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
|
|
||||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
|
||||||
f := &KernelFactory{
|
|
||||||
wgPort: wgPort,
|
|
||||||
mtu: mtu,
|
|
||||||
}
|
|
||||||
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *KernelFactory) GetProxy() Proxy {
|
|
||||||
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *KernelFactory) Free() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -3,24 +3,25 @@ package wgproxy
|
|||||||
import (
|
import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||||
)
|
)
|
||||||
|
|
||||||
type USPFactory struct {
|
type USPFactory struct {
|
||||||
bind *bind.ICEBind
|
bind proxyBind.Bind
|
||||||
|
mtu uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
|
||||||
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
||||||
f := &USPFactory{
|
f := &USPFactory{
|
||||||
bind: iceBind,
|
bind: bind,
|
||||||
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) GetProxy() Proxy {
|
func (w *USPFactory) GetProxy() Proxy {
|
||||||
return proxyBind.NewProxyBind(w.bind)
|
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *USPFactory) Free() error {
|
func (w *USPFactory) Free() error {
|
||||||
|
|||||||
@@ -11,6 +11,11 @@ type Proxy interface {
|
|||||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||||
Work() // Work start or resume the proxy
|
Work() // Work start or resume the proxy
|
||||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||||
|
|
||||||
|
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
|
||||||
|
//and rewrite the src address to the endpoint address.
|
||||||
|
//With this logic can avoid the package loss from relayed connections.
|
||||||
|
RedirectAs(endpoint *net.UDPAddr)
|
||||||
CloseConn() error
|
CloseConn() error
|
||||||
SetDisconnectListener(disconnected func())
|
SetDisconnectListener(disconnected func())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,54 +3,82 @@
|
|||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"fmt"
|
||||||
"os"
|
"net"
|
||||||
"testing"
|
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
func seedProxies() ([]proxyInstance, error) {
|
||||||
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
pl := make([]proxyInstance, 0)
|
||||||
t.Skip("Skipping test as it requires root privileges")
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
pEbpf := proxyInstance{
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
name: "ebpf kernel proxy",
|
||||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||||
|
wgPort: 51831,
|
||||||
|
closeFn: ebpfProxy.Free,
|
||||||
}
|
}
|
||||||
}()
|
pl = append(pl, pEbpf)
|
||||||
|
|
||||||
tests := []struct {
|
pUDP := proxyInstance{
|
||||||
name string
|
name: "udp kernel proxy",
|
||||||
proxy Proxy
|
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||||
}{
|
wgPort: 51832,
|
||||||
{
|
closeFn: func() error { return nil },
|
||||||
name: "ebpf proxy",
|
}
|
||||||
proxy: &ebpf.ProxyWrapper{
|
pl = append(pl, pUDP)
|
||||||
WgeBPFProxy: ebpfProxy,
|
return pl, nil
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
pl := make([]proxyInstance, 0)
|
||||||
relayedConn := newMockConn()
|
|
||||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||||
|
if err := ebpfProxy.Listen(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pEbpf := proxyInstance{
|
||||||
|
name: "ebpf kernel proxy",
|
||||||
|
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||||
|
wgPort: 51831,
|
||||||
|
closeFn: ebpfProxy.Free,
|
||||||
|
}
|
||||||
|
pl = append(pl, pEbpf)
|
||||||
|
|
||||||
|
pUDP := proxyInstance{
|
||||||
|
name: "udp kernel proxy",
|
||||||
|
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||||
|
wgPort: 51832,
|
||||||
|
closeFn: func() error { return nil },
|
||||||
|
}
|
||||||
|
pl = append(pl, pUDP)
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error: %v", err)
|
return nil, err
|
||||||
|
}
|
||||||
|
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||||
|
endpointAddress := &net.UDPAddr{
|
||||||
|
IP: net.IPv4(10, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = relayedConn.Close()
|
pBind := proxyInstance{
|
||||||
if err := tt.proxy.CloseConn(); err != nil {
|
name: "bind proxy",
|
||||||
t.Errorf("error: %v", err)
|
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||||
}
|
endpointAddr: endpointAddress,
|
||||||
})
|
closeFn: func() error { return nil },
|
||||||
}
|
}
|
||||||
|
pl = append(pl, pBind)
|
||||||
|
|
||||||
|
return pl, nil
|
||||||
}
|
}
|
||||||
|
|||||||
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package wgproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||||
|
)
|
||||||
|
|
||||||
|
func seedProxies() ([]proxyInstance, error) {
|
||||||
|
// todo extend with Bind proxy
|
||||||
|
pl := make([]proxyInstance, 0)
|
||||||
|
return pl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||||
|
pl := make([]proxyInstance, 0)
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||||
|
endpointAddress := &net.UDPAddr{
|
||||||
|
IP: net.IPv4(10, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
|
||||||
|
pBind := proxyInstance{
|
||||||
|
name: "bind proxy",
|
||||||
|
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||||
|
endpointAddr: endpointAddress,
|
||||||
|
closeFn: func() error { return nil },
|
||||||
|
}
|
||||||
|
pl = append(pl, pBind)
|
||||||
|
return pl, nil
|
||||||
|
}
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package wgproxy
|
package wgproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,12 +5,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
|
||||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
|
|||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type proxyInstance struct {
|
||||||
|
name string
|
||||||
|
proxy Proxy
|
||||||
|
wgPort int
|
||||||
|
endpointAddr *net.UDPAddr
|
||||||
|
closeFn func() error
|
||||||
|
}
|
||||||
|
|
||||||
type mocConn struct {
|
type mocConn struct {
|
||||||
closeChan chan struct{}
|
closeChan chan struct{}
|
||||||
closed bool
|
closed bool
|
||||||
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
|||||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
tests := []struct {
|
tests, err := seedProxyForProxyCloseByRemoteConn()
|
||||||
name string
|
if err != nil {
|
||||||
proxy Proxy
|
t.Fatalf("error: %v", err)
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "userspace proxy",
|
|
||||||
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
|
||||||
if err := ebpfProxy.Listen(); err != nil {
|
|
||||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := ebpfProxy.Free(); err != nil {
|
_ = relayedConn.Close()
|
||||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
|
||||||
|
|
||||||
tests = append(tests, struct {
|
|
||||||
name string
|
|
||||||
proxy Proxy
|
|
||||||
}{
|
|
||||||
name: "ebpf proxy",
|
|
||||||
proxy: proxyWrapper,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
|
||||||
relayedConn := newMockConn()
|
relayedConn := newMockConn()
|
||||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error: %v", err)
|
t.Errorf("error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestProxyRedirect todo extend the proxies with Bind proxy
|
||||||
|
func TestProxyRedirect(t *testing.T) {
|
||||||
|
tests, err := seedProxies()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
|
||||||
|
if err := tt.closeFn(); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
msgHelloFromRelay := []byte("hello from relay")
|
||||||
|
msgRedirected := [][]byte{
|
||||||
|
[]byte("hello 1. to p2p"),
|
||||||
|
[]byte("hello 2. to p2p"),
|
||||||
|
[]byte("hello 3. to p2p"),
|
||||||
|
}
|
||||||
|
|
||||||
|
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: wgPort})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to listen on udp port: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayedServer, _ := net.ListenUDP("udp",
|
||||||
|
&net.UDPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = dummyWgListener.Close()
|
||||||
|
_ = relayedConn.Close()
|
||||||
|
_ = relayedServer.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := proxy.CloseConn(); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy.Work()
|
||||||
|
|
||||||
|
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
|
||||||
|
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := dummyWgListener.Read(make([]byte, 1024))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n != len(msgHelloFromRelay) {
|
||||||
|
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
p2pEndpointAddr := &net.UDPAddr{
|
||||||
|
IP: net.IPv4(192, 168, 0, 56),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
proxy.RedirectAs(p2pEndpointAddr)
|
||||||
|
|
||||||
|
for _, msg := range msgRedirected {
|
||||||
|
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(msgRedirected); i++ {
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, rAddr, err := dummyWgListener.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rAddr.String() != p2pEndpointAddr.String() {
|
||||||
|
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
|
||||||
|
}
|
||||||
|
if string(buf[:n]) != string(msgRedirected[i]) {
|
||||||
|
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package rawsocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||||
|
// Create a raw socket.
|
||||||
|
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||||
|
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind the socket to the "lo" interface.
|
||||||
|
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the fwmark on the socket.
|
||||||
|
err = nbnet.SetSocketOpt(fd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the file descriptor to a PacketConn.
|
||||||
|
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||||
|
if file == nil {
|
||||||
|
return nil, fmt.Errorf("converting fd to file failed")
|
||||||
|
}
|
||||||
|
packetConn, err := net.FilePacketConn(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return packetConn, nil
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
package udp
|
package udp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -23,13 +25,15 @@ type WGUDPProxy struct {
|
|||||||
|
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
localConn net.Conn
|
localConn net.Conn
|
||||||
|
srcFakerConn *SrcFaker
|
||||||
|
sendPkg func(data []byte) (int, error)
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
closeMu sync.Mutex
|
closeMu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
pausedMu sync.Mutex
|
|
||||||
paused bool
|
paused bool
|
||||||
|
pausedCond *sync.Cond
|
||||||
isStarted bool
|
isStarted bool
|
||||||
|
|
||||||
closeListener *listener.CloseListener
|
closeListener *listener.CloseListener
|
||||||
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
|||||||
p := &WGUDPProxy{
|
p := &WGUDPProxy{
|
||||||
localWGListenPort: wgPort,
|
localWGListenPort: wgPort,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
|
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||||
closeListener: listener.NewCloseListener(),
|
closeListener: listener.NewCloseListener(),
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
|
|||||||
|
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
p.localConn = localConn
|
p.localConn = localConn
|
||||||
|
p.sendPkg = p.localConn.Write
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
p.pausedMu.Unlock()
|
p.sendPkg = p.localConn.Write
|
||||||
|
|
||||||
|
if p.srcFakerConn != nil {
|
||||||
|
if err := p.srcFakerConn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close src faker conn: %s", err)
|
||||||
|
}
|
||||||
|
p.srcFakerConn = nil
|
||||||
|
}
|
||||||
|
|
||||||
if !p.isStarted {
|
if !p.isStarted {
|
||||||
p.isStarted = true
|
p.isStarted = true
|
||||||
go p.proxyToRemote(p.ctx)
|
go p.proxyToRemote(p.ctx)
|
||||||
go p.proxyToLocal(p.ctx)
|
go p.proxyToLocal(p.ctx)
|
||||||
}
|
}
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pause pauses the proxy from receiving data from the remote peer
|
// Pause pauses the proxy from receiving data from the remote peer
|
||||||
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
p.paused = true
|
p.paused = true
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.L.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedirectAs start to use the fake sourced raw socket as package sender
|
||||||
|
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||||
|
p.pausedCond.L.Lock()
|
||||||
|
defer func() {
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
p.paused = false
|
||||||
|
if p.srcFakerConn != nil {
|
||||||
|
if err := p.srcFakerConn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close src faker conn: %s", err)
|
||||||
|
}
|
||||||
|
p.srcFakerConn = nil
|
||||||
|
}
|
||||||
|
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create src faker conn: %s", err)
|
||||||
|
// fallback to continue without redirecting
|
||||||
|
p.paused = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.srcFakerConn = srcFakerConn
|
||||||
|
p.sendPkg = p.srcFakerConn.SendPkg
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseConn close the localConn
|
// CloseConn close the localConn
|
||||||
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *WGUDPProxy) close() error {
|
func (p *WGUDPProxy) close() error {
|
||||||
|
var result *multierror.Error
|
||||||
|
|
||||||
p.closeMu.Lock()
|
p.closeMu.Lock()
|
||||||
defer p.closeMu.Unlock()
|
defer p.closeMu.Unlock()
|
||||||
|
|
||||||
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
|
|||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
var result *multierror.Error
|
p.pausedCond.L.Lock()
|
||||||
|
p.paused = false
|
||||||
|
p.pausedCond.Signal()
|
||||||
|
p.pausedCond.L.Unlock()
|
||||||
|
|
||||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||||
}
|
}
|
||||||
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
|
|||||||
if err := p.localConn.Close(); err != nil {
|
if err := p.localConn.Close(); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.srcFakerConn != nil {
|
||||||
|
if err := p.srcFakerConn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return cerrors.FormatErrorOrNil(result)
|
return cerrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedCond.L.Lock()
|
||||||
if p.paused {
|
for p.paused {
|
||||||
p.pausedMu.Unlock()
|
p.pausedCond.Wait()
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
_, err = p.sendPkg(buf[:n])
|
||||||
_, err = p.localConn.Write(buf[:n])
|
p.pausedCond.L.Unlock()
|
||||||
p.pausedMu.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
|
|||||||
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serializeOpts = gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
localHostNetIPAddr = &net.IPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type SrcFaker struct {
|
||||||
|
srcAddr *net.UDPAddr
|
||||||
|
|
||||||
|
rawSocket net.PacketConn
|
||||||
|
ipH gopacket.SerializableLayer
|
||||||
|
udpH gopacket.SerializableLayer
|
||||||
|
layerBuffer gopacket.SerializeBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||||
|
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f := &SrcFaker{
|
||||||
|
srcAddr: srcAddr,
|
||||||
|
rawSocket: rawSocket,
|
||||||
|
ipH: ipH,
|
||||||
|
udpH: udpH,
|
||||||
|
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *SrcFaker) Close() error {
|
||||||
|
return f.rawSocket.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||||
|
defer func() {
|
||||||
|
if err := f.layerBuffer.Clear(); err != nil {
|
||||||
|
log.Errorf("failed to clear layer buffer: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := gopacket.Payload(data)
|
||||||
|
|
||||||
|
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||||
|
}
|
||||||
|
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||||
|
ipH := &layers.IPv4{
|
||||||
|
DstIP: net.ParseIP("127.0.0.1"),
|
||||||
|
SrcIP: srcAddr.IP,
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
udpH := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||||
|
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||||
|
}
|
||||||
|
|
||||||
|
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipH, udpH, nil
|
||||||
|
}
|
||||||
@@ -34,7 +34,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -240,15 +240,19 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
|||||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
for i, domain := range domains {
|
for i, domain := range domains {
|
||||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||||
if r.gpo {
|
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
|
||||||
}
|
|
||||||
|
|
||||||
singleDomain := []string{domain}
|
singleDomain := []string{domain}
|
||||||
|
|
||||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.gpo {
|
||||||
|
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||||
|
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||||
@@ -401,6 +405,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
|||||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||||
}
|
}
|
||||||
@@ -412,6 +417,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
|||||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||||
}
|
}
|
||||||
|
|||||||
5
client/internal/dns/server_js.go
Normal file
5
client/internal/dns/server_js.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||||
|
return &noopHostConfigurator{}, nil
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServiceViaMemory struct {
|
type ServiceViaMemory struct {
|
||||||
|
|||||||
19
client/internal/dns/unclean_shutdown_js.go
Normal file
19
client/internal/dns/unclean_shutdown_js.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ShutdownState struct{}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -11,13 +12,17 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||||
|
listenPort uint16 = 5353
|
||||||
|
listenPortMu sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
|
||||||
ListenPort = 5353
|
|
||||||
dnsTTL = 60 //seconds
|
dnsTTL = 60 //seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,12 +40,20 @@ type Manager struct {
|
|||||||
fwRules []firewall.Rule
|
fwRules []firewall.Rule
|
||||||
tcpRules []firewall.Rule
|
tcpRules []firewall.Rule
|
||||||
dnsForwarder *DNSForwarder
|
dnsForwarder *DNSForwarder
|
||||||
|
port uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
func ListenPort() uint16 {
|
||||||
|
listenPortMu.RLock()
|
||||||
|
defer listenPortMu.RUnlock()
|
||||||
|
return listenPort
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
firewall: fw,
|
firewall: fw,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
port: port,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +67,13 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
|
if m.port > 0 {
|
||||||
|
listenPortMu.Lock()
|
||||||
|
listenPort = m.port
|
||||||
|
listenPortMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||||
go func() {
|
go func() {
|
||||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||||
// todo handle close error if it is exists
|
// todo handle close error if it is exists
|
||||||
@@ -94,7 +113,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
|||||||
func (m *Manager) allowDNSFirewall() error {
|
func (m *Manager) allowDNSFirewall() error {
|
||||||
dport := &firewall.Port{
|
dport := &firewall.Port{
|
||||||
IsRange: false,
|
IsRange: false,
|
||||||
Values: []uint16{ListenPort},
|
Values: []uint16{ListenPort()},
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.firewall == nil {
|
if m.firewall == nil {
|
||||||
|
|||||||
@@ -198,6 +198,13 @@ type Engine struct {
|
|||||||
latestSyncResponse *mgmProto.SyncResponse
|
latestSyncResponse *mgmProto.SyncResponse
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
|
// WireGuard interface monitor
|
||||||
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
|
wgIfaceMonitorWg sync.WaitGroup
|
||||||
|
|
||||||
|
// dns forwarder port
|
||||||
|
dnsFwdPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -240,6 +247,7 @@ func NewEngine(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
|
dnsFwdPort: dnsfwd.ListenPort(),
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := profilemanager.NewServiceManager("")
|
sm := profilemanager.NewServiceManager("")
|
||||||
@@ -341,6 +349,9 @@ func (e *Engine) Stop() error {
|
|||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop WireGuard interface monitor and wait for it to exit
|
||||||
|
e.wgIfaceMonitorWg.Wait()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,14 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return fmt.Errorf("initialize dns server: %w", err)
|
return fmt.Errorf("initialize dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iceCfg := icemaker.Config{
|
iceCfg := e.createICEConfig()
|
||||||
StunTurn: &e.stunTurn,
|
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
|
||||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
|
||||||
UDPMuxSrflx: e.udpMux,
|
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
|
||||||
}
|
|
||||||
|
|
||||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||||
e.connMgr.Start(e.ctx)
|
e.connMgr.Start(e.ctx)
|
||||||
@@ -477,6 +481,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
// starting network monitor at the very last to avoid disruptions
|
// starting network monitor at the very last to avoid disruptions
|
||||||
e.startNetworkMonitor()
|
e.startNetworkMonitor()
|
||||||
|
|
||||||
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
|
e.wgIfaceMonitorWg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer e.wgIfaceMonitorWg.Done()
|
||||||
|
|
||||||
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||||
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
|
e.restartEngine()
|
||||||
|
} else if err != nil {
|
||||||
|
log.Warnf("WireGuard interface monitor: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1064,7 +1084,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
|
||||||
|
|
||||||
// Ingress forward rules
|
// Ingress forward rules
|
||||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||||
@@ -1322,14 +1342,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
Addr: e.getRosenpassAddr(),
|
Addr: e.getRosenpassAddr(),
|
||||||
PermissiveMode: e.config.RosenpassPermissive,
|
PermissiveMode: e.config.RosenpassPermissive,
|
||||||
},
|
},
|
||||||
ICEConfig: icemaker.Config{
|
ICEConfig: e.createICEConfig(),
|
||||||
StunTurn: &e.stunTurn,
|
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
|
||||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
|
||||||
UDPMuxSrflx: e.udpMux,
|
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceDependencies := peer.ServiceDependencies{
|
serviceDependencies := peer.ServiceDependencies{
|
||||||
@@ -1830,6 +1843,7 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
|||||||
func (e *Engine) updateDNSForwarder(
|
func (e *Engine) updateDNSForwarder(
|
||||||
enabled bool,
|
enabled bool,
|
||||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||||
|
forwarderPort uint16,
|
||||||
) {
|
) {
|
||||||
if e.config.DisableServerRoutes {
|
if e.config.DisableServerRoutes {
|
||||||
return
|
return
|
||||||
@@ -1846,16 +1860,20 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(fwdEntries) > 0 {
|
if len(fwdEntries) > 0 {
|
||||||
if e.dnsForwardMgr == nil {
|
switch {
|
||||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
case e.dnsForwardMgr == nil:
|
||||||
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||||
log.Errorf("failed to start DNS forward: %v", err)
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||||
} else {
|
case e.dnsFwdPort != forwarderPort:
|
||||||
|
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||||
|
e.restartDnsFwd(fwdEntries, forwarderPort)
|
||||||
|
e.dnsFwdPort = forwarderPort
|
||||||
|
|
||||||
|
default:
|
||||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||||
}
|
}
|
||||||
} else if e.dnsForwardMgr != nil {
|
} else if e.dnsForwardMgr != nil {
|
||||||
@@ -1865,6 +1883,20 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
}
|
}
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
|
||||||
|
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||||
|
// stop and start the forwarder to apply the new port
|
||||||
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
|
}
|
||||||
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||||
|
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||||
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
|
e.dnsForwardMgr = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||||
|
|||||||
19
client/internal/engine_generic.go
Normal file
19
client/internal/engine_generic.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createICEConfig creates ICE configuration for non-WASM environments
|
||||||
|
func (e *Engine) createICEConfig() icemaker.Config {
|
||||||
|
return icemaker.Config{
|
||||||
|
StunTurn: &e.stunTurn,
|
||||||
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
|
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||||
|
UDPMuxSrflx: e.udpMux,
|
||||||
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
}
|
||||||
|
}
|
||||||
18
client/internal/engine_js.go
Normal file
18
client/internal/engine_js.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
//go:build js
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createICEConfig creates ICE configuration for WASM environment.
|
||||||
|
func (e *Engine) createICEConfig() icemaker.Config {
|
||||||
|
cfg := icemaker.Config{
|
||||||
|
StunTurn: &e.stunTurn,
|
||||||
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
@@ -27,6 +27,10 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -42,10 +46,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/peers"
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
@@ -1584,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/ti-mo/netfilter"
|
"github.com/ti-mo/netfilter"
|
||||||
|
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultChannelSize = 100
|
const defaultChannelSize = 100
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
|||||||
|
|
||||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||||
// check dns collection
|
// check dns collection
|
||||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
12
client/internal/networkmonitor/check_change_js.go
Normal file
12
client/internal/networkmonitor/check_change_js.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
// No-op for WASM - network changes don't apply
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -28,10 +28,6 @@ import (
|
|||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
defaultWgKeepAlive = 25 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServiceDependencies struct {
|
type ServiceDependencies struct {
|
||||||
StatusRecorder *Status
|
StatusRecorder *Status
|
||||||
Signaler *Signaler
|
Signaler *Signaler
|
||||||
@@ -117,6 +113,8 @@ type Conn struct {
|
|||||||
|
|
||||||
// debug purpose
|
// debug purpose
|
||||||
dumpState *stateDump
|
dumpState *stateDump
|
||||||
|
|
||||||
|
endpointUpdater *EndpointUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
@@ -140,6 +138,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
|||||||
statusRelay: worker.NewAtomicStatus(),
|
statusRelay: worker.NewAtomicStatus(),
|
||||||
statusICE: worker.NewAtomicStatus(),
|
statusICE: worker.NewAtomicStatus(),
|
||||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||||
|
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
@@ -249,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
|||||||
conn.wgProxyICE = nil
|
conn.wgProxyICE = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.removeWgPeer(); err != nil {
|
if err := conn.endpointUpdater.RemoveWgPeer(); err != nil {
|
||||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -375,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
|
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||||
|
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||||
|
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||||
conn.handleConfigurationFailure(err, wgProxy)
|
conn.handleConfigurationFailure(err, wgProxy)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wgConfigWorkaround()
|
wgConfigWorkaround()
|
||||||
|
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
conn.Log.Debugf("redirect packets from relayed conn to WireGuard")
|
||||||
|
conn.wgProxyRelay.RedirectAs(ep)
|
||||||
|
}
|
||||||
|
|
||||||
conn.currentConnPriority = priority
|
conn.currentConnPriority = priority
|
||||||
conn.statusICE.SetConnected()
|
conn.statusICE.SetConnected()
|
||||||
conn.updateIceState(iceConnInfo)
|
conn.updateIceState(iceConnInfo)
|
||||||
@@ -409,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
conn.dumpState.SwitchToRelay()
|
conn.dumpState.SwitchToRelay()
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||||
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -418,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
defer conn.wgWatcherWg.Done()
|
defer conn.wgWatcherWg.Done()
|
||||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||||
}()
|
}()
|
||||||
|
conn.wgProxyRelay.Work()
|
||||||
conn.currentConnPriority = conntype.Relay
|
conn.currentConnPriority = conntype.Relay
|
||||||
} else {
|
} else {
|
||||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||||
@@ -477,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||||
|
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
@@ -545,17 +554,6 @@ func (conn *Conn) onGuardEvent() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
|
|
||||||
presharedKey := conn.presharedKey(remoteRPKey)
|
|
||||||
return conn.config.WgConfig.WgInterface.UpdatePeer(
|
|
||||||
conn.config.WgConfig.RemoteKey,
|
|
||||||
conn.config.WgConfig.AllowedIps,
|
|
||||||
defaultWgKeepAlive,
|
|
||||||
addr,
|
|
||||||
presharedKey,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||||
peerState := State{
|
peerState := State{
|
||||||
PubKey: conn.config.Key,
|
PubKey: conn.config.Key,
|
||||||
@@ -698,10 +696,6 @@ func (conn *Conn) isICEActive() bool {
|
|||||||
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) removeWgPeer() error {
|
|
||||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||||
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||||
if wgProxy != nil {
|
if wgProxy != nil {
|
||||||
|
|||||||
105
client/internal/peer/endpoint.go
Normal file
105
client/internal/peer/endpoint.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultWgKeepAlive = 25 * time.Second
|
||||||
|
fallbackDelay = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type EndpointUpdater struct {
|
||||||
|
log *logrus.Entry
|
||||||
|
wgConfig WgConfig
|
||||||
|
initiator bool
|
||||||
|
|
||||||
|
// mu protects updateWireGuardPeer and cancelFunc
|
||||||
|
mu sync.Mutex
|
||||||
|
cancelFunc func()
|
||||||
|
updateWg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater {
|
||||||
|
return &EndpointUpdater{
|
||||||
|
log: log,
|
||||||
|
wgConfig: wgConfig,
|
||||||
|
initiator: initiator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||||
|
// The initiator immediately configures the endpoint, while the non-initiator
|
||||||
|
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||||
|
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
if e.initiator {
|
||||||
|
e.log.Debugf("configure up WireGuard as initiatr")
|
||||||
|
return e.updateWireGuardPeer(addr, presharedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// prevent to run new update while cancel the previous update
|
||||||
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||||
|
e.updateWg.Add(1)
|
||||||
|
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||||
|
|
||||||
|
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||||
|
return e.updateWireGuardPeer(nil, presharedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
e.waitForCloseTheDelayedUpdate()
|
||||||
|
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||||
|
if e.cancelFunc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.cancelFunc()
|
||||||
|
e.cancelFunc = nil
|
||||||
|
e.updateWg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
|
||||||
|
func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) {
|
||||||
|
defer e.updateWg.Done()
|
||||||
|
t := time.NewTimer(fallbackDelay)
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-t.C:
|
||||||
|
e.mu.Lock()
|
||||||
|
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||||
|
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
||||||
|
}
|
||||||
|
e.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||||
|
return e.wgConfig.WgInterface.UpdatePeer(
|
||||||
|
e.wgConfig.RemoteKey,
|
||||||
|
e.wgConfig.AllowedIps,
|
||||||
|
defaultWgKeepAlive,
|
||||||
|
endpoint,
|
||||||
|
presharedKey,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -3,6 +3,8 @@ package guard
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,7 +26,7 @@ type ICEMonitor struct {
|
|||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
iceConfig icemaker.Config
|
iceConfig icemaker.Config
|
||||||
|
|
||||||
currentCandidates []ice.Candidate
|
currentCandidatesAddress []string
|
||||||
candidatesMu sync.Mutex
|
candidatesMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,17 +117,22 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
|
|||||||
cm.candidatesMu.Lock()
|
cm.candidatesMu.Lock()
|
||||||
defer cm.candidatesMu.Unlock()
|
defer cm.candidatesMu.Unlock()
|
||||||
|
|
||||||
if len(cm.currentCandidates) != len(newCandidates) {
|
newAddresses := make([]string, len(newCandidates))
|
||||||
cm.currentCandidates = newCandidates
|
for i, c := range newCandidates {
|
||||||
|
newAddresses[i] = c.Address()
|
||||||
|
}
|
||||||
|
sort.Strings(newAddresses)
|
||||||
|
|
||||||
|
if len(cm.currentCandidatesAddress) != len(newAddresses) {
|
||||||
|
cm.currentCandidatesAddress = newAddresses
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, candidate := range cm.currentCandidates {
|
// Compare elements
|
||||||
if candidate.Address() != newCandidates[i].Address() {
|
if !slices.Equal(cm.currentCandidatesAddress, newAddresses) {
|
||||||
cm.currentCandidates = newCandidates
|
cm.currentCandidatesAddress = newAddresses
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -218,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
|
if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
|
||||||
|
w.onICESelectedCandidatePair(agent, c1, c2)
|
||||||
|
}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,26 +367,17 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
|
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||||
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||||
w.config.Key)
|
w.config.Key)
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
pairStat, ok := agent.GetSelectedCandidatePairStats()
|
||||||
|
if !ok {
|
||||||
pair, err := w.agent.GetSelectedCandidatePair()
|
w.log.Warnf("failed to get selected candidate pair stats")
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to get selected candidate pair: %s", err)
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if pair == nil {
|
|
||||||
w.log.Warnf("selected candidate pair is nil, cannot proceed")
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
|
|
||||||
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
|
duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
|
||||||
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
|
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
|
||||||
w.log.Debugf("failed to update latency for peer: %s", err)
|
w.log.Debugf("failed to update latency for peer: %s", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProbeResult holds the info about the result of a relay probe request
|
// ProbeResult holds the info about the result of a relay probe request
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const dnsTimeout = 8 * time.Second
|
const dnsTimeout = 8 * time.Second
|
||||||
@@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||||
|
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||||
|
}
|
||||||
|
|
||||||
dm := &DefaultManager{
|
dm := &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
|
||||||
log.Warnf("Failed cleaning up routing: %v", err)
|
log.Warnf("Failed cleaning up routing: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
|
|||||||
|
|
||||||
ips := resolveURLsToIPs(initialAddresses)
|
ips := resolveURLsToIPs(initialAddresses)
|
||||||
|
|
||||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||||
return fmt.Errorf("setup routing: %w", err)
|
return fmt.Errorf("setup routing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Routing cleanup complete")
|
log.Info("Routing cleanup complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
nbnet.SetVPNInterfaceName("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -22,7 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const localSubnetsCacheTTL = 15 * time.Minute
|
const localSubnetsCacheTTL = 15 * time.Minute
|
||||||
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Remove hooks selectively
|
hooks.RemoveWriteHooks()
|
||||||
nbnet.RemoveDialerHooks()
|
hooks.RemoveCloseHooks()
|
||||||
nbnet.RemoveListenerHooks()
|
hooks.RemoveAddressRemoveHooks()
|
||||||
|
|
||||||
if err := r.refCounter.Flush(); err != nil {
|
if err := r.refCounter.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush route manager: %w", err)
|
return fmt.Errorf("flush route manager: %w", err)
|
||||||
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||||
prefix, err := util.GetPrefixFromIP(ip)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||||
return fmt.Errorf("adding route reference: %v", err)
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
}
|
}
|
||||||
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
afterHook := func(connID nbnet.ConnectionID) error {
|
afterHook := func(connID hooks.ConnectionID) error {
|
||||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
}
|
}
|
||||||
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, ip := range initAddresses {
|
for _, ip := range initAddresses {
|
||||||
if err := beforeHook("init", ip); err != nil {
|
prefix, err := util.GetPrefixFromIP(ip)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := beforeHook("init", prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
hooks.AddWriteHook(beforeHook)
|
||||||
if ctx.Err() != nil {
|
hooks.AddCloseHook(afterHook)
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||||
for _, ip := range resolvedIPs {
|
|
||||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
|
||||||
return beforeHook(connID, ip.IP)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
|
||||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialer interface {
|
type dialer interface {
|
||||||
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||||
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
index, err := net.InterfaceByName(wgInterface.Name())
|
index, err := net.InterfaceByName(wgInterface.Name())
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
r.prefixes = make(map[netip.Prefix]struct{})
|
r.prefixes = make(map[netip.Prefix]struct{})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrRouteNotSupported = errors.New("route operations not supported on js")
|
||||||
|
|
||||||
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
|
return []netip.Prefix{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
return []netip.Prefix{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||||
|
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||||
|
return []DetailedRoute{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPRule contains IP rule information for debugging
|
// IPRule contains IP rule information for debugging
|
||||||
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
|
|||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
|
||||||
if !nbnet.AdvancedRouting() {
|
if !advancedRouting {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
if !nbnet.AdvancedRouting() {
|
if !advancedRouting {
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !ios
|
//go:build !linux && !ios && !js
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
|
|||||||
@@ -20,11 +20,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketExpectation struct {
|
type PacketExpectation struct {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -19,9 +20,16 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const InfiniteLifetime = 0xffffffff
|
func init() {
|
||||||
|
nbnet.GetBestInterfaceFunc = GetBestInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
InfiniteLifetime = 0xffffffff
|
||||||
|
)
|
||||||
|
|
||||||
type RouteUpdateType int
|
type RouteUpdateType int
|
||||||
|
|
||||||
@@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct {
|
|||||||
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
|
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// candidateRoute represents a potential route for selection during route lookup
|
||||||
|
type candidateRoute struct {
|
||||||
|
interfaceIndex uint32
|
||||||
|
prefixLength uint8
|
||||||
|
routeMetric uint32
|
||||||
|
interfaceMetric int
|
||||||
|
}
|
||||||
|
|
||||||
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
|
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
|
||||||
type IP_ADDRESS_PREFIX struct {
|
type IP_ADDRESS_PREFIX struct {
|
||||||
Prefix SOCKADDR_INET
|
Prefix SOCKADDR_INET
|
||||||
@@ -177,11 +193,20 @@ const (
|
|||||||
RouteDeleted
|
RouteDeleted
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
|
if advancedRouting {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Using legacy routing setup with ref counters")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
|
if advancedRouting {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
|
|||||||
|
|
||||||
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
|
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
|
||||||
if table != nil {
|
if table != nil {
|
||||||
ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
_, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
||||||
if ret != 0 {
|
|
||||||
log.Warnf("FreeMibTable failed with return code: %d", ret)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
|
|||||||
entryPtr := basePtr + uintptr(i)*entrySize
|
entryPtr := basePtr + uintptr(i)*entrySize
|
||||||
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||||
|
|
||||||
detailed := buildWindowsDetailedRoute(entry)
|
if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
|
||||||
if detailed != nil {
|
|
||||||
detailedRoutes = append(detailedRoutes, *detailed)
|
detailedRoutes = append(detailedRoutes, *detailed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
|
|||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseCandidatesFromTable extracts all matching candidate routes from the routing table
|
||||||
|
func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
|
||||||
|
var candidates []candidateRoute
|
||||||
|
entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
|
||||||
|
basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
|
||||||
|
|
||||||
|
for i := uint32(0); i < table.NumEntries; i++ {
|
||||||
|
entryPtr := basePtr + uintptr(i)*entrySize
|
||||||
|
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||||
|
|
||||||
|
if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
|
||||||
|
candidates = append(candidates, *candidate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
|
||||||
|
// Returns nil if the route doesn't match the destination or should be skipped
|
||||||
|
func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
|
||||||
|
if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
|
||||||
|
if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
|
||||||
|
|
||||||
|
return &candidateRoute{
|
||||||
|
interfaceIndex: entry.InterfaceIndex,
|
||||||
|
prefixLength: entry.DestinationPrefix.PrefixLength,
|
||||||
|
routeMetric: entry.Metric,
|
||||||
|
interfaceMetric: interfaceMetric,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getInterfaceMetric retrieves the interface metric for a given interface and address family
|
// getInterfaceMetric retrieves the interface metric for a given interface and address family
|
||||||
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
||||||
if interfaceIndex == 0 {
|
if interfaceIndex == 0 {
|
||||||
@@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
|||||||
return int(ipInterfaceRow.Metric)
|
return int(ipInterfaceRow.Metric)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
|
||||||
|
func sortRouteCandidates(candidates []candidateRoute) {
|
||||||
|
sort.Slice(candidates, func(i, j int) bool {
|
||||||
|
if candidates[i].prefixLength != candidates[j].prefixLength {
|
||||||
|
return candidates[i].prefixLength > candidates[j].prefixLength
|
||||||
|
}
|
||||||
|
if candidates[i].routeMetric != candidates[j].routeMetric {
|
||||||
|
return candidates[i].routeMetric < candidates[j].routeMetric
|
||||||
|
}
|
||||||
|
return candidates[i].interfaceMetric < candidates[j].interfaceMetric
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBestInterface finds the best interface for reaching a destination,
|
||||||
|
// excluding the VPN interface to avoid routing loops.
|
||||||
|
//
|
||||||
|
// Route selection priority:
|
||||||
|
// 1. Longest prefix match (most specific route)
|
||||||
|
// 2. Lowest route metric
|
||||||
|
// 3. Lowest interface metric
|
||||||
|
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
|
||||||
|
var skipInterfaceIndex int
|
||||||
|
if vpnIntf != "" {
|
||||||
|
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
|
||||||
|
skipInterfaceIndex = iface.Index
|
||||||
|
} else {
|
||||||
|
// not critical, if we cannot get ahold of the interface then we won't need to skip it
|
||||||
|
log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := getWindowsRoutingTable()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get routing table: %w", err)
|
||||||
|
}
|
||||||
|
defer freeWindowsRoutingTable(table)
|
||||||
|
|
||||||
|
candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
|
||||||
|
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, fmt.Errorf("no route to %s", dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort routes: prefix length -> route metric -> interface metric
|
||||||
|
sortRouteCandidates(candidates)
|
||||||
|
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
|
log.Debugf("interface %s is down, trying next route", iface.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
|
||||||
|
dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
|
||||||
|
return iface, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("no usable interface found for %s", dest)
|
||||||
|
}
|
||||||
|
|
||||||
// formatRouteAge formats the route age in seconds to a human-readable string
|
// formatRouteAge formats the route age in seconds to a human-readable string
|
||||||
func formatRouteAge(ageSeconds uint32) string {
|
func formatRouteAge(ageSeconds uint32) string {
|
||||||
if ageSeconds == 0 {
|
if ageSeconds == 0 {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
addr = addr.Unmap()
|
addr = addr.Unmap()
|
||||||
|
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
||||||
var prefixLength int
|
|
||||||
switch {
|
|
||||||
case addr.Is4():
|
|
||||||
prefixLength = 32
|
|
||||||
case addr.Is6():
|
|
||||||
prefixLength = 128
|
|
||||||
default:
|
|
||||||
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
|
||||||
return prefix, nil
|
return prefix, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dial connects to the address on the named network.
|
// Dial connects to the address on the named network.
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ListenPacket listens for incoming packets on the given network and address.
|
// ListenPacket listens for incoming packets on the given network and address.
|
||||||
|
|||||||
98
client/internal/wg_iface_monitor.go
Normal file
98
client/internal/wg_iface_monitor.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||||
|
// if the interface is deleted externally while the engine is running.
|
||||||
|
type WGIfaceMonitor struct {
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWGIfaceMonitor creates a new WGIfaceMonitor instance.
|
||||||
|
func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||||
|
return &WGIfaceMonitor{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins monitoring the WireGuard interface.
|
||||||
|
// It relies on the provided context cancellation to stop.
|
||||||
|
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||||
|
defer close(m.done)
|
||||||
|
|
||||||
|
// Skip on mobile platforms as they handle interface lifecycle differently
|
||||||
|
if runtime.GOOS == "android" || runtime.GOOS == "ios" {
|
||||||
|
log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS)
|
||||||
|
return false, errors.New("not supported on mobile platforms")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ifaceName == "" {
|
||||||
|
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||||
|
return false, errors.New("empty interface name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get initial interface index to track the specific interface instance
|
||||||
|
expectedIndex, err := getInterfaceIndex(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName)
|
||||||
|
return false, fmt.Errorf("interface %s not found: %w", ifaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||||
|
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
// Interface was deleted
|
||||||
|
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||||
|
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if interface index changed (interface was recreated)
|
||||||
|
if currentIndex != expectedIndex {
|
||||||
|
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||||
|
ifaceName, expectedIndex, currentIndex)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// getInterfaceIndex returns the index of a network interface by name.
|
||||||
|
// Returns an error if the interface is not found.
|
||||||
|
func getInterfaceIndex(name string) (int, error) {
|
||||||
|
if name == "" {
|
||||||
|
return 0, fmt.Errorf("empty interface name")
|
||||||
|
}
|
||||||
|
ifi, err := net.InterfaceByName(name)
|
||||||
|
if err != nil {
|
||||||
|
// Check if it's specifically a "not found" error
|
||||||
|
if errors.Is(err, &net.OpError{}) {
|
||||||
|
// On some systems, this might be a "not found" error
|
||||||
|
return 0, fmt.Errorf("interface not found: %w", err)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("failed to lookup interface: %w", err)
|
||||||
|
}
|
||||||
|
if ifi == nil {
|
||||||
|
return 0, fmt.Errorf("interface not found")
|
||||||
|
}
|
||||||
|
return ifi.Index, nil
|
||||||
|
}
|
||||||
49
client/net/conn.go
Normal file
49
client/net/conn.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn wraps a net.Conn to override the Close method
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
ID hooks.ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
return closeConn(c.ID, c.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPConn wraps net.TCPConn to override its Close method to include hook functionality.
|
||||||
|
type TCPConn struct {
|
||||||
|
*net.TCPConn
|
||||||
|
ID hooks.ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *TCPConn) Close() error {
|
||||||
|
return closeConn(c.ID, c.TCPConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConn is a helper function to close connections and execute close hooks.
|
||||||
|
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||||
|
err := conn.Close()
|
||||||
|
|
||||||
|
closeHooks := hooks.GetCloseHooks()
|
||||||
|
for _, hook := range closeHooks {
|
||||||
|
if err := hook(id); err != nil {
|
||||||
|
log.Errorf("Error executing close hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
82
client/net/dial.go
Normal file
82
client/net/dial.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c := conn.(type) {
|
||||||
|
case *net.UDPConn:
|
||||||
|
// Advanced routing: plain connection
|
||||||
|
return c, nil
|
||||||
|
case *Conn:
|
||||||
|
// Legacy routing: wrapped connection preserves close hooks
|
||||||
|
udpConn, ok := c.Conn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn)
|
||||||
|
}
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c := conn.(type) {
|
||||||
|
case *net.TCPConn:
|
||||||
|
// Advanced routing: plain connection
|
||||||
|
return c, nil
|
||||||
|
case *Conn:
|
||||||
|
// Legacy routing: wrapped connection preserves close hooks
|
||||||
|
tcpConn, ok := c.Conn.(*net.TCPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn)
|
||||||
|
}
|
||||||
|
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user