diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index f7b4e238f..ba36c013b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -217,7 +217,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "" + raceFlag: "-race" runs-on: ubuntu-22.04 steps: - name: Install Go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 7e6583cc6..2845b05a5 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros skip: go.mod,go.sum golangci: strategy: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7be52259b..e9741f541 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.22" + SIGN_PIPE_VER: "v0.0.23" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml new file mode 100644 index 000000000..e4ac799bc --- /dev/null +++ b/.github/workflows/wasm-build-validation.yml @@ -0,0 +1,67 @@ +name: Wasm + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true + +jobs: + js_lint: + name: "JS / Lint" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev + - name: Install golangci-lint + uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc + with: + version: latest + install-mode: binary + skip-cache: true + skip-pkg-cache: true + skip-build-cache: true + - name: Run golangci-lint for WASM + run: | + GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/... + continue-on-error: true + + js_build: + name: "JS / Build" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Build Wasm client + run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd + env: + CGO_ENABLED: 0 + - name: Check Wasm build size + run: | + echo "Wasm build size:" + ls -lh netbird.wasm + + SIZE=$(stat -c%s netbird.wasm) + SIZE_MB=$((SIZE / 1024 / 1024)) + + echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" + + if [ ${SIZE} -gt 52428800 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!" + exit 1 + fi + diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..e69de29bb diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 59a95c89a..952e946dc 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -2,6 +2,18 @@ version: 2 project_name: netbird builds: + - id: netbird-wasm + dir: client/wasm/cmd + binary: netbird + env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0] + goos: + - js + goarch: + - wasm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird dir: client binary: netbird @@ -115,6 +127,11 @@ archives: - builds: - netbird - netbird-static + - id: netbird-wasm + builds: + - netbird-wasm + name_template: "{{ .ProjectName }}_{{ .Version }}" + format: binary nfpms: - maintainer: Netbird diff --git a/README.md b/README.md index ea7655869..2c5ee2ab6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +


@@ -52,7 +53,7 @@ ### Open Source Network Security in a Single Platform -centralized-network-management 1 +https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) diff --git a/client/Dockerfile b/client/Dockerfile index e19a09909..b2f627409 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -18,7 +18,7 @@ ENV \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + NB_ENTRYPOINT_LOGIN_TIMEOUT="5" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/android/client.go b/client/android/client.go index c05246569..d2d0c37f6 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,6 +4,7 @@ package android import ( "context" + "os" "slices" "sync" @@ -18,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net" ) // ConnectionListener export internal Listener for mobile @@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } + } +} diff --git a/client/android/env_list.go b/client/android/env_list.go new file mode 100644 index 000000000..04122300a --- /dev/null +++ b/client/android/env_list.go @@ -0,0 +1,32 @@ +package android + +import "github.com/netbirdio/netbird/client/internal/peer" + +var ( + // EnvKeyNBForceRelay Exported for Android java client + EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay +) + +// EnvList wraps a Go map for export to Java +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair +func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} diff --git a/client/android/login.go b/client/android/login.go index d8ac645e2..0df78dbc3 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -33,6 +33,7 @@ type ErrListener interface { // the backend want to show an url for the user type URLOpener interface { Open(string) + OnLoginSuccess() } // Auth can register or login new client @@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error { err = a.withBackOff(a.ctx, func() error { err := internal.Login(a.ctx, a.config, "", jwtToken) + + if err == nil { + go urlOpener.OnLoginSuccess() + } + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return nil } diff --git a/client/cmd/debug_js.go b/client/cmd/debug_js.go new file mode 100644 index 000000000..d06fb8efc --- /dev/null +++ b/client/cmd/debug_js.go @@ -0,0 +1,8 @@ +package cmd + +import "context" + +// SetupDebugHandler is a no-op for WASM +func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) { + // Debug handler not needed for WASM +} diff --git a/client/cmd/down.go b/client/cmd/down.go index 3ce51c678..17c152d22 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -27,7 +27,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/cmd/root.go b/client/cmd/root.go index 5084bd38a..11e5228f1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string { // DialClientGRPCServer returns client connection to the daemon server. func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*3) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() return grpc.DialContext( diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 99ccb1539..bd3209605 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" @@ -20,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index e686625d6..d047c041e 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager client := proto.NewDaemonServiceClient(conn) - status, err := client.Status(ctx, &proto.StatusRequest{}) + status, err := client.Status(ctx, &proto.StatusRequest{ + WaitForReady: func() *bool { b := true; return &b }(), + }) if err != nil { return fmt.Errorf("unable to get daemon status: %v", err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index de83f9d96..e918235ed 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -23,23 +23,29 @@ import ( var ErrClientAlreadyStarted = errors.New("client already started") var ErrClientNotStarted = errors.New("client not started") +var ErrConfigNotInitialized = errors.New("config not initialized") -// Client manages a netbird embedded client instance +// Client manages a netbird embedded client instance. type Client struct { deviceName string config *profilemanager.Config mu sync.Mutex cancel context.CancelFunc setupKey string + jwtToken string connect *internal.ConnectClient } -// Options configures a new Client +// Options configures a new Client. type Options struct { // DeviceName is this peer's name in the network DeviceName string // SetupKey is used for authentication SetupKey string + // JWTToken is used for JWT-based authentication + JWTToken string + // PrivateKey is used for direct private key authentication + PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string // PreSharedKey is the pre-shared key for the WireGuard interface @@ -58,8 +64,35 @@ type Options struct { DisableClientRoutes bool } -// New creates a new netbird embedded client +// validateCredentials checks that exactly one credential type is provided +func (opts *Options) validateCredentials() error { + credentialsProvided := 0 + if opts.SetupKey != "" { + credentialsProvided++ + } + if opts.JWTToken != "" { + credentialsProvided++ + } + if opts.PrivateKey != "" { + credentialsProvided++ + } + + if credentialsProvided == 0 { + return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided") + } + if credentialsProvided > 1 { + return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified") + } + + return nil +} + +// New creates a new netbird embedded client. func New(opts Options) (*Client, error) { + if err := opts.validateCredentials(); err != nil { + return nil, err + } + if opts.LogOutput != nil { logrus.SetOutput(opts.LogOutput) } @@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) { return nil, fmt.Errorf("create config: %w", err) } + if opts.PrivateKey != "" { + config.PrivateKey = opts.PrivateKey + } + return &Client{ deviceName: opts.DeviceName, setupKey: opts.SetupKey, + jwtToken: opts.JWTToken, config: config, }, nil } @@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error { ctx := internal.CtxInitState(context.Background()) // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) - if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil { + if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } @@ -135,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error { // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available - run := make(chan struct{}, 1) + run := make(chan struct{}) clientErr := make(chan error, 1) go func() { if err := client.Run(run); err != nil { @@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error { } } +// GetConfig returns a copy of the internal client config. +func (c *Client) GetConfig() (profilemanager.Config, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.config == nil { + return profilemanager.Config{}, ErrConfigNotInitialized + } + return *c.config, nil +} + // Dial dials a network address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { @@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e return nsnet.DialContext(ctx, network, address) } -// ListenTCP listens on the given address in the netbird network +// ListenTCP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenTCP(address string) (net.Listener, error) { nsnet, addr, err := c.getNet() @@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) { return nsnet.ListenTCP(tcpAddr) } -// ListenUDP listens on the given address in the netbird network +// ListenUDP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenUDP(address string) (net.PacketConn, error) { nsnet, addr, err := c.getNet() diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 7b90000a8..ed8a7403b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -12,7 +12,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 1e44c7a4d..081991235 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // constants needed to manage and create iptable rules diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index e9eeff863..3490c5dad 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,7 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) func isIptablesSupported() bool { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 52979d257..9ff5b8c92 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index f8fed4d80..e918d0524 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -22,7 +22,7 @@ import ( nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/util/grpc/dialer.go b/client/grpc/dialer.go similarity index 51% rename from util/grpc/dialer.go rename to client/grpc/dialer.go index f6d6d2f04..54fbb002c 100644 --- a/util/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,15 +4,9 @@ import ( "context" "crypto/tls" "crypto/x509" - "fmt" - "net" - "os/user" "runtime" "time" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -21,35 +15,9 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" ) -func WithCustomDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - if runtime.GOOS == "linux" { - currentUser, err := user.Current() - if err != nil { - return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) - } - - // the custom dialer requires root permissions which are not required for use cases run as non-root - if currentUser.Uid != "0" { - log.Debug("Not running as root, using standard dialer") - dialer := &net.Dialer{} - return dialer.DialContext(ctx, "tcp", addr) - } - } - - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) - if err != nil { - log.Errorf("Failed to dial: %s", err) - return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) - } - return conn, nil - }) -} - -// grpcDialBackoff is the backoff mechanism for the grpc calls +// Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() b.MaxElapsedTime = 10 * time.Second @@ -57,7 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +// CreateConnection creates a gRPC client connection with the appropriate transport options. +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -67,18 +37,20 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { } transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - RootCAs: certPool, + // for js, outer websocket layer takes care of tls verification via WithCustomDialer + InsecureSkipVerify: runtime.GOOS == "js", + RootCAs: certPool, })) } - connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() conn, err := grpc.DialContext( connCtx, addr, transportOption, - WithCustomDialer(), + WithCustomDialer(tlsEnabled, component), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go new file mode 100644 index 000000000..96f347c64 --- /dev/null +++ b/client/grpc/dialer_generic.go @@ -0,0 +1,44 @@ +//go:build !js + +package grpc + +import ( + "context" + "fmt" + "net" + "os/user" + "runtime" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + if runtime.GOOS == "linux" { + currentUser, err := user.Current() + if err != nil { + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) + } + + // the custom dialer requires root permissions which are not required for use cases run as non-root + if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") + dialer := &net.Dialer{} + return dialer.DialContext(ctx, "tcp", addr) + } + } + + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) + } + return conn, nil + }) +} diff --git a/client/grpc/dialer_js.go b/client/grpc/dialer_js.go new file mode 100644 index 000000000..b89ec3c21 --- /dev/null +++ b/client/grpc/dialer_js.go @@ -0,0 +1,13 @@ +package grpc + +import ( + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy/client" +) + +// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { + return client.WithWebSocketDialer(tlsEnabled, component) +} diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go index 89bddf12c..32b07c330 100644 --- a/client/iface/bind/control.go +++ b/client/iface/bind/control.go @@ -3,7 +3,7 @@ package bind import ( wireguard "golang.zx2c4.com/wireguard/conn" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go index 1926ff88f..caa92f05d 100644 --- a/client/iface/bind/endpoint.go +++ b/client/iface/bind/endpoint.go @@ -1,5 +1,17 @@ package bind -import wgConn "golang.zx2c4.com/wireguard/conn" +import ( + "net" + + wgConn "golang.zx2c4.com/wireguard/conn" +) type Endpoint = wgConn.StdNetEndpoint + +func EndpointToUDPAddr(e Endpoint) *net.UDPAddr { + return &net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + } +} diff --git a/client/iface/bind/error.go b/client/iface/bind/error.go new file mode 100644 index 000000000..db7c23144 --- /dev/null +++ b/client/iface/bind/error.go @@ -0,0 +1,7 @@ +package bind + +import "fmt" + +var ( + ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM") +) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 359d2129b..dfb22ecde 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,6 +1,9 @@ +//go:build !js + package bind import ( + "context" "encoding/binary" "fmt" "net" @@ -15,15 +18,11 @@ import ( "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) -type RecvMessage struct { - Endpoint *Endpoint - Buffer []byte -} - type receiverCreator struct { iceBind *ICEBind } @@ -41,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - RecvChan chan RecvMessage transportNet transport.Net - filterFn FilterFn - endpoints map[netip.Addr]net.Conn - endpointsMu sync.Mutex + filterFn udpmux.FilterFn + address wgaddr.Address + mtu uint16 + + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + recvChan chan recvMessage // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // new closed channel. With the closedChanMu we can safely close the channel and create a new one - closedChan chan struct{} - closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. - closed bool - - muUDPMux sync.Mutex - udpMux *UniversalUDPMuxDefault - address wgaddr.Address - mtu uint16 + closedChan chan struct{} + closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. + closed bool activityRecorder *ActivityRecorder + + muUDPMux sync.Mutex + udpMux *udpmux.UniversalUDPMuxDefault } -func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - RecvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, + address: address, + mtu: mtu, endpoints: make(map[netip.Addr]net.Conn), + recvChan: make(chan recvMessage, 1), closedChan: make(chan struct{}), closed: true, - mtu: mtu, - address: address, activityRecorder: NewActivityRecorder(), } @@ -82,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad return ib } -func (s *ICEBind) MTU() uint16 { - return s.mtu -} - func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { s.closed = false s.closedChanMu.Lock() @@ -115,7 +111,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder { } // GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { +func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() if s.udpMux == nil { @@ -138,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) { delete(b.endpoints, fakeIP) } +func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-b.closedChan: + return + case <-ctx.Done(): + return + case b.recvChan <- recvMessage{ep, buf}: + } +} + func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { b.endpointsMu.Lock() conn, ok := b.endpoints[ep.DstIP()] @@ -158,8 +164,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = NewUniversalUDPMuxDefault( - UniversalUDPMuxParams{ + s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ UDPConn: nbnet.WrapPacketConn(conn), Net: s.transportNet, FilterFn: s.filterFn, @@ -270,7 +276,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo select { case <-c.closedChan: return 0, net.ErrClosed - case msg, ok := <-c.RecvChan: + case msg, ok := <-c.recvChan: if !ok { return 0, net.ErrClosed } diff --git a/client/iface/bind/recv_msg.go b/client/iface/bind/recv_msg.go new file mode 100644 index 000000000..65baffaac --- /dev/null +++ b/client/iface/bind/recv_msg.go @@ -0,0 +1,6 @@ +package bind + +type recvMessage struct { + Endpoint *Endpoint + Buffer []byte +} diff --git a/client/iface/bind/relay_bind.go b/client/iface/bind/relay_bind.go new file mode 100644 index 000000000..4c179d6a5 --- /dev/null +++ b/client/iface/bind/relay_bind.go @@ -0,0 +1,125 @@ +package bind + +import ( + "context" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" + + "github.com/netbirdio/netbird/client/iface/udpmux" +) + +// RelayBindJS is a conn.Bind implementation for WebAssembly environments. +// Do not limit to build only js, because we want to be able to run tests +type RelayBindJS struct { + *conn.StdNetBind + + recvChan chan recvMessage + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + activityRecorder *ActivityRecorder + ctx context.Context + cancel context.CancelFunc +} + +func NewRelayBindJS() *RelayBindJS { + return &RelayBindJS{ + recvChan: make(chan recvMessage, 100), + endpoints: make(map[netip.Addr]net.Conn), + activityRecorder: NewActivityRecorder(), + } +} + +// Open creates a receive function for handling relay packets in WASM. +func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { + log.Debugf("Open: creating receive function for port %d", uport) + + s.ctx, s.cancel = context.WithCancel(context.Background()) + + receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + select { + case <-s.ctx.Done(): + return 0, net.ErrClosed + case msg, ok := <-s.recvChan: + if !ok { + return 0, net.ErrClosed + } + copy(bufs[0], msg.Buffer) + sizes[0] = len(msg.Buffer) + eps[0] = conn.Endpoint(msg.Endpoint) + return 1, nil + } + } + + log.Debugf("Open: receive function created, returning port %d", uport) + return []conn.ReceiveFunc{receiveFn}, uport, nil +} + +func (s *RelayBindJS) Close() error { + if s.cancel == nil { + return nil + } + log.Debugf("close RelayBindJS") + s.cancel() + return nil +} + +func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-s.ctx.Done(): + return + case <-ctx.Done(): + return + case s.recvChan <- recvMessage{ep, buf}: + } +} + +// Send forwards packets through the relay connection for WASM. +func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error { + if ep == nil { + return nil + } + + fakeIP := ep.DstIP() + + s.endpointsMu.Lock() + relayConn, ok := s.endpoints[fakeIP] + s.endpointsMu.Unlock() + + if !ok { + return nil + } + + for _, buf := range bufs { + if _, err := relayConn.Write(buf); err != nil { + return err + } + } + + return nil +} + +func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + b.endpointsMu.Lock() + b.endpoints[fakeIP] = conn + b.endpointsMu.Unlock() +} + +func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) { + s.endpointsMu.Lock() + defer s.endpointsMu.Unlock() + + delete(s.endpoints, fakeIP) +} + +// GetICEMux returns the ICE UDPMux that was created and used by ICEBind +func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { + return nil, ErrUDPMUXNotSupported +} + +func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder { + return s.activityRecorder +} diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go deleted file mode 100644 index db0249d11..000000000 --- a/client/iface/bind/udp_mux_ios.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build ios - -package bind - -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { - // iOS doesn't support nbnet hooks, so this is a no-op -} diff --git a/client/iface/configurer/name.go b/client/iface/configurer/name.go index 3b9abc0e8..a8469e0b4 100644 --- a/client/iface/configurer/name.go +++ b/client/iface/configurer/name.go @@ -1,4 +1,4 @@ -//go:build linux || windows || freebsd +//go:build linux || windows || freebsd || js || wasip1 package configurer diff --git a/client/iface/configurer/uapi.go b/client/iface/configurer/uapi.go index 4801841de..f85c7852a 100644 --- a/client/iface/configurer/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build !windows && !js package configurer diff --git a/client/iface/configurer/uapi_js.go b/client/iface/configurer/uapi_js.go new file mode 100644 index 000000000..d0188eb35 --- /dev/null +++ b/client/iface/configurer/uapi_js.go @@ -0,0 +1,23 @@ +package configurer + +import ( + "net" +) + +type noopListener struct{} + +func (n *noopListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (n *noopListener) Close() error { + return nil +} + +func (n *noopListener) Addr() net.Addr { + return nil +} + +func openUAPI(deviceName string) (net.Listener, error) { + return &noopListener{}, nil +} diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 171458e38..f744e0127 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -17,8 +17,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/monotime" - nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) { if err != nil { return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) } + + // If sec is 0 (Unix epoch), return zero time instead + // This indicates no handshake has occurred + if sec == 0 { + return time.Time{}, nil + } + return time.Unix(sec, 0), nil } @@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) { } func getFwmark() int { - if nbnet.AdvancedRouting() { + if nbnet.AdvancedRouting() && runtime.GOOS == "linux" { return nbnet.ControlPlaneMark } return 0 diff --git a/client/iface/device.go b/client/iface/device.go index ca6dda2c2..921f0ea98 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -7,14 +7,14 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create() (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index fe3b9f82e..a731684cc 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type WGTunDevice struct { name string device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string } return t.configurer, nil } -func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index cce9d42df..390efe088 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -26,7 +27,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 168985b5e..96e4c8bcf 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -28,7 +29,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 00a72bcc6..cdac43a53 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -12,11 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/sharedsock" - nbnet "github.com/netbirdio/netbird/util/net" ) type TunKernelDevice struct { @@ -31,9 +31,9 @@ type TunKernelDevice struct { link *wgLink udpMuxConn net.PacketConn - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault - filterFn bind.FilterFn + filterFn udpmux.FilterFn } func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { @@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { return configurer, nil } -func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -101,19 +101,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return nil, err } - var udpConn net.PacketConn = rawSock - if !nbnet.AdvancedRouting() { - udpConn = nbnet.WrapPacketConn(rawSock) - } - - bindParams := bind.UniversalUDPMuxParams{ - UDPConn: udpConn, + bindParams := udpmux.UniversalUDPMuxParams{ + UDPConn: nbnet.WrapPacketConn(rawSock), Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, MTU: t.mtu, } - mux := bind.NewUniversalUDPMuxDefault(bindParams) + mux := udpmux.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) t.udpMuxConn = rawSock t.udpMux = mux diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index f41331ff7..e37321b68 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,19 +1,28 @@ package device import ( + "errors" "fmt" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) +type Bind interface { + conn.Bind + GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) + ActivityRecorder() *bind.ActivityRecorder +} + type TunNetstackDevice struct { name string address wgaddr.Address @@ -21,18 +30,18 @@ type TunNetstackDevice struct { key string mtu uint16 listenAddress string - iceBind *bind.ICEBind + bind Bind device *device.Device filteredDevice *FilteredDevice nsTun *nbnetstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer net *netstack.Net } -func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { +func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -40,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: iceBind, + bind: bind, } } @@ -65,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { t.device = device.NewDevice( t.filteredDevice, - t.iceBind, + t.bind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() @@ -80,7 +89,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -90,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return nil, err } - udpMux, err := t.iceBind.GetICEMux() - if err != nil { + udpMux, err := t.bind.GetICEMux() + if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) { return nil, err } - t.udpMux = udpMux + + if udpMux != nil { + t.udpMux = udpMux + } + log.Debugf("netstack device is ready to use") return udpMux, nil } diff --git a/client/iface/device/device_netstack_test.go b/client/iface/device/device_netstack_test.go new file mode 100644 index 000000000..52059602f --- /dev/null +++ b/client/iface/device/device_netstack_test.go @@ -0,0 +1,27 @@ +package device + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func TestNewNetstackDevice(t *testing.T) { + privateKey, _ := wgtypes.GeneratePrivateKey() + wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24") + + relayBind := bind.NewRelayBindJS() + nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr()) + + cfgr, err := nsTun.Create() + if err != nil { + t.Fatalf("failed to create netstack device: %v", err) + } + if cfgr == nil { + t.Fatal("expected non-nil configurer") + } +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 8d30112ae..4cdd70a32 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -25,7 +26,7 @@ type USPDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index de258868f..f1023bc0a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type TunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 39b5c28ae..4649b8b97 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -5,14 +5,14 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/iface.go b/client/iface/iface.go index 9a42223a1..609572561 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -16,9 +16,9 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -61,7 +61,7 @@ type WGIFaceOpts struct { MTU uint16 MobileArgs *device.MobileIFaceArguments TransportNet transport.Net - FilterFn bind.FilterFn + FilterFn udpmux.FilterFn DisableDNS bool } @@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface { // Up configures a Wireguard interface // The interface must exist before calling this method (e.g. call interface.Create() before) -func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/iface/iface_destroy_js.go b/client/iface/iface_destroy_js.go new file mode 100644 index 000000000..b443273c3 --- /dev/null +++ b/client/iface/iface_destroy_js.go @@ -0,0 +1,6 @@ +package iface + +// Destroy is a no-op on WASM +func (w *WGIface) Destroy() error { + return nil +} diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 26952f48d..3b68f63f2 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go index 7dd74d571..9f21ec950 100644 --- a/client/iface/iface_new_darwin.go +++ b/client/iface/iface_new_darwin.go @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go new file mode 100644 index 000000000..a342bd579 --- /dev/null +++ b/client/iface/iface_new_freebsd.go @@ -0,0 +1,41 @@ +//go:build freebsd + +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) + return wgIFace, nil + } + + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) + return wgIFace, nil + } + + return nil, fmt.Errorf("couldn't check or load tun module") +} diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 06ccf0be1..5d6a32e39 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_js.go b/client/iface/iface_new_js.go new file mode 100644 index 000000000..ad913ab04 --- /dev/null +++ b/client/iface/iface_new_js.go @@ -0,0 +1,27 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode) +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + relayBind := bind.NewRelayBindJS() + + wgIface := &WGIface{ + tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU), + } + + return wgIface, nil +} diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_linux.go similarity index 89% rename from client/iface/iface_new_unix.go rename to client/iface/iface_new_linux.go index 493144f13..d84035403 100644 --- a/client/iface/iface_new_unix.go +++ b/client/iface/iface_new_linux.go @@ -1,4 +1,4 @@ -//go:build (linux && !android) || freebsd +//go:build linux && !android package iface @@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } @@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go index 349c5b33b..dfd9028e7 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new_windows.go @@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil diff --git a/client/iface/netstack/env.go b/client/iface/netstack/env.go index cdbf975b1..dd8cf29a3 100644 --- a/client/iface/netstack/env.go +++ b/client/iface/netstack/env.go @@ -1,3 +1,5 @@ +//go:build !js + package netstack import ( diff --git a/client/iface/netstack/env_js.go b/client/iface/netstack/env_js.go new file mode 100644 index 000000000..05c20f036 --- /dev/null +++ b/client/iface/netstack/env_js.go @@ -0,0 +1,12 @@ +package netstack + +const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE" + +// IsEnabled always returns true for js since it's the only mode available +func IsEnabled() bool { + return true +} + +func ListenAddr() string { + return "" +} diff --git a/client/iface/bind/udp_muxed_conn.go b/client/iface/udpmux/conn.go similarity index 95% rename from client/iface/bind/udp_muxed_conn.go rename to client/iface/udpmux/conn.go index 7cacf1c31..3aa40caeb 100644 --- a/client/iface/bind/udp_muxed_conn.go +++ b/client/iface/udpmux/conn.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements @@ -16,11 +16,12 @@ import ( ) type udpMuxedConnParams struct { - Mux *UDPMuxDefault - AddrPool *sync.Pool - Key string - LocalAddr net.Addr - Logger logging.LeveledLogger + Mux *SingleSocketUDPMux + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger + CandidateID string } // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag @@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error { return err } +func (c *udpMuxedConn) GetCandidateID() string { + return c.params.CandidateID +} + func (c *udpMuxedConn) isClosed() bool { select { case <-c.closedChan: diff --git a/client/iface/udpmux/doc.go b/client/iface/udpmux/doc.go new file mode 100644 index 000000000..27e5e43bc --- /dev/null +++ b/client/iface/udpmux/doc.go @@ -0,0 +1,64 @@ +// Package udpmux provides a custom implementation of a UDP multiplexer +// that allows multiple logical ICE connections to share a single underlying +// UDP socket. This is based on Pion's ICE library, with modifications for +// NetBird's requirements. +// +// # Background +// +// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity +// Establishment) is responsible for discovering candidate network paths +// and maintaining connectivity between peers. Each ICE connection +// normally requires a dedicated UDP socket. However, using one socket +// per candidate can be inefficient and difficult to manage. +// +// This package introduces SingleSocketUDPMux, which allows multiple ICE +// candidate connections (muxed connections) to share a single UDP socket. +// It handles demultiplexing of packets based on ICE ufrag values, STUN +// attributes, and candidate IDs. +// +// # Usage +// +// The typical flow is: +// +// 1. Create a UDP socket (net.PacketConn). +// 2. Construct Params with the socket and optional logger/net stack. +// 3. Call NewSingleSocketUDPMux(params). +// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID) +// to obtain a logical PacketConn. +// 5. Use the returned PacketConn just like a normal UDP connection. +// +// # STUN Message Routing Logic +// +// When a STUN packet arrives, the mux decides which connection should +// receive it using this routing logic: +// +// Primary Routing: Candidate Pair ID +// - Extract the candidate pair ID from the STUN message using +// ice.CandidatePairIDFromSTUN(msg) +// - The target candidate is the locally generated candidate that +// corresponds to the connection that should handle this STUN message +// - If found, use the target candidate ID to lookup the specific +// connection in candidateConnMap +// - Route the message directly to that connection +// +// Fallback Routing: Broadcasting +// When candidate pair ID is not available or lookup fails: +// - Collect connections from addressMap based on source address +// - Find connection using username attribute (ufrag) from STUN message +// - Remove duplicate connections from the list +// - Send the STUN message to all collected connections +// +// # Peer Reflexive Candidate Discovery +// +// When a remote peer sends a STUN message from an unknown source address +// (from a candidate that has not been exchanged via signal), the ICE +// library will: +// - Generate a new peer reflexive candidate for this source address +// - Extract or assign a candidate ID based on the STUN message attributes +// - Create a mapping between the new peer reflexive candidate ID and +// the appropriate local connection +// +// This discovery mechanism ensures that STUN messages from newly discovered +// peer reflexive candidates can be properly routed to the correct local +// connection without requiring fallback broadcasting. +package udpmux diff --git a/client/iface/bind/udp_mux.go b/client/iface/udpmux/mux.go similarity index 65% rename from client/iface/bind/udp_mux.go rename to client/iface/udpmux/mux.go index db7494405..319724926 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/udpmux/mux.go @@ -1,4 +1,4 @@ -package bind +package udpmux import ( "fmt" @@ -22,9 +22,9 @@ import ( const receiveMTU = 8192 -// UDPMuxDefault is an implementation of the interface -type UDPMuxDefault struct { - params UDPMuxParams +// SingleSocketUDPMux is an implementation of the interface +type SingleSocketUDPMux struct { + params Params closedChan chan struct{} closeOnce sync.Once @@ -32,6 +32,9 @@ type UDPMuxDefault struct { // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn + // candidateConnMap maps local candidate IDs to their corresponding connection. + candidateConnMap map[string]*udpMuxedConn + addressMapMu sync.RWMutex addressMap map[string][]*udpMuxedConn @@ -46,8 +49,8 @@ type UDPMuxDefault struct { const maxAddrSize = 512 -// UDPMuxParams are parameters for UDPMux. -type UDPMuxParams struct { +// Params are parameters for UDPMux. +type Params struct { Logger logging.LeveledLogger UDPConn net.PacketConn @@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool { return true } -// NewUDPMuxDefault creates an implementation of UDPMux -func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { +// NewSingleSocketUDPMux creates an implementation of UDPMux +func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { if params.Logger == nil { params.Logger = getLogger() } - mux := &UDPMuxDefault{ - addressMap: map[string][]*udpMuxedConn{}, - params: params, - connsIPv4: make(map[string]*udpMuxedConn), - connsIPv6: make(map[string]*udpMuxedConn), - closedChan: make(chan struct{}, 1), + mux := &SingleSocketUDPMux{ + addressMap: map[string][]*udpMuxedConn{}, + params: params, + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), + candidateConnMap: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address @@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return mux } -func (m *UDPMuxDefault) updateLocalAddresses() { +func (m *SingleSocketUDPMux) updateLocalAddresses() { var localAddrsForUnspecified []net.Addr if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) } else if ok && addr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection - // with UDPMuxDefault, so print a warn log and create a local address list for mux. - m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + // with SingleSocketUDPMux, so print a warn log and create a local address list for mux. + m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { @@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() { m.mu.Unlock() } -// LocalAddr returns the listening address of this UDPMuxDefault -func (m *UDPMuxDefault) LocalAddr() net.Addr { +// LocalAddr returns the listening address of this SingleSocketUDPMux +func (m *SingleSocketUDPMux) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetListenAddresses returns the list of addresses that this mux is listening on -func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { +func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { m.updateLocalAddresses() m.mu.Lock() @@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { // GetConn returns a PacketConn given the connection's ufrag and network address // creates the connection if an existing one can't be found -func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { +func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) { // don't check addr for mux using unspecified address m.mu.Lock() lenLocalAddrs := len(m.localAddrsForUnspecified) @@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er return conn, nil } - c := m.createMuxedConn(ufrag) + c := m.createMuxedConn(ufrag, candidateID) go func() { <-c.CloseChannel() m.RemoveConnByUfrag(ufrag) }() + m.candidateConnMap[candidateID] = c + if isIPv6 { m.connsIPv6[ufrag] = c } else { @@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er } // RemoveConnByUfrag stops and removes the muxed packet connection -func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { +func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock section small to avoid deadlock with conn lock @@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } m.mu.Unlock() @@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { } // IsClosed returns true if the mux had been closed -func (m *UDPMuxDefault) IsClosed() bool { +func (m *SingleSocketUDPMux) IsClosed() bool { select { case <-m.closedChan: return true @@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool { } // Close the mux, no further connections could be created -func (m *UDPMuxDefault) Close() error { +func (m *SingleSocketUDPMux) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() @@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error { return err } -func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { +func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } -func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { +func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } @@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } -func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { +func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ - Mux: m, - Key: key, - AddrPool: m.pool, - LocalAddr: m.LocalAddr(), - Logger: m.params.Logger, + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: m.LocalAddr(), + Logger: m.params.Logger, + CandidateID: candidateID, }) return c } // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library -func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { - +func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { remoteAddr, ok := addr.(*net.UDPAddr) if !ok { return fmt.Errorf("underlying PacketConn did not return a UDPAddr") } - // If we have already seen this address dispatch to the appropriate destination - // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one - // muxed connection - one for the SRFLX candidate and the other one for the HOST one. - // We will then forward STUN packets to each of these connections. - m.addressMapMu.RLock() + // Try to route to specific candidate connection first + if conn := m.findCandidateConnection(msg); conn != nil { + return conn.writePacket(msg.Raw, remoteAddr) + } + + // Fallback: route to all possible connections + return m.forwardToAllConnections(msg, addr, remoteAddr) +} + +// findCandidateConnection attempts to find the specific connection for a STUN message +func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn { + candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg) + if err != nil { + return nil + } else if !ok { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()] + if !exists { + return nil + } + return conn +} + +// forwardToAllConnections forwards STUN message to all relevant connections +func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error { var destinationConnList []*udpMuxedConn + + // Add connections from address map + m.addressMapMu.RLock() if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = append(destinationConnList, storedConns...) } m.addressMapMu.RUnlock() - var isIPv6 bool - if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { - isIPv6 = true + if conn, ok := m.findConnectionByUsername(msg, addr); ok { + // If we have already seen this address dispatch to the appropriate destination + // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one + // muxed connection - one for the SRFLX candidate and the other one for the HOST one. + // We will then forward STUN packets to each of these connections. + if !m.connectionExists(conn, destinationConnList) { + destinationConnList = append(destinationConnList, conn) + } } - // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. - // However, we can take a username attribute from the STUN message which contains ufrag. - // We can use ufrag to identify the destination conn to route packet to. - attr, stunAttrErr := msg.Get(stun.AttrUsername) - if stunAttrErr == nil { - ufrag := strings.Split(string(attr), ":")[0] - - m.mu.Lock() - destinationConn := m.connsIPv4[ufrag] - if isIPv6 { - destinationConn = m.connsIPv6[ufrag] - } - - if destinationConn != nil { - exists := false - for _, conn := range destinationConnList { - if conn.params.Key == destinationConn.params.Key { - exists = true - break - } - } - if !exists { - destinationConnList = append(destinationConnList, destinationConn) - } - } - m.mu.Unlock() - } - - // Forward STUN packets to each destination connections even thought the STUN packet might not belong there. - // It will be discarded by the further ICE candidate logic if so. + // Forward to all found connections for _, conn := range destinationConnList { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { log.Errorf("could not write packet: %v", err) } } - return nil } -func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { +// findConnectionByUsername finds connection using username attribute from STUN message +func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) { + attr, err := msg.Get(stun.AttrUsername) + if err != nil { + return nil, false + } + + ufrag := strings.Split(string(attr), ":")[0] + isIPv6 := isIPv6Address(addr) + + m.mu.Lock() + defer m.mu.Unlock() + + return m.getConn(ufrag, isIPv6) +} + +// connectionExists checks if a connection already exists in the list +func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool { + for _, conn := range conns { + if conn.params.Key == target.params.Key { + return true + } + } + return false +} + +func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { @@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o return } +func isIPv6Address(addr net.Addr) bool { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + return udpAddr.IP.To4() == nil + } + return false +} + type bufferHolder struct { buf []byte } diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/udpmux/mux_generic.go similarity index 76% rename from client/iface/bind/udp_mux_generic.go rename to client/iface/udpmux/mux_generic.go index 63f786d2b..29fc2d834 100644 --- a/client/iface/bind/udp_mux_generic.go +++ b/client/iface/udpmux/mux_generic.go @@ -1,12 +1,12 @@ //go:build !ios -package bind +package udpmux import ( - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { conn.RemoveAddress(addr) diff --git a/client/iface/udpmux/mux_ios.go b/client/iface/udpmux/mux_ios.go new file mode 100644 index 000000000..4cf211d8f --- /dev/null +++ b/client/iface/udpmux/mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package udpmux + +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/udpmux/universal.go similarity index 97% rename from client/iface/bind/udp_mux_universal.go rename to client/iface/udpmux/universal.go index a1f517dcd..43bfedaaa 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/udpmux/universal.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. @@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. type UniversalUDPMuxDefault struct { - *UDPMuxDefault + *SingleSocketUDPMux params UniversalUDPMuxParams // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents @@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - udpMuxParams := UDPMuxParams{ + udpMuxParams := Params{ Logger: params.Logger, UDPConn: m.params.UDPConn, Net: m.params.Net, } - m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams) return m } @@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. -func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { - return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) { + return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID) } // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. @@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A } return nil } - return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) + return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr) } // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index bf6da72c2..eb585d8a2 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -16,28 +16,38 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) -type ProxyBind struct { - Bind *bind.ICEBind - - fakeNetIP *netip.AddrPort - wgBindEndpoint *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool - - pausedMu sync.Mutex - paused bool - isStarted bool - - closeListener *listener.CloseListener +type Bind interface { + SetEndpoint(addr netip.Addr, conn net.Conn) + RemoveEndpoint(addr netip.Addr) + ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte) } -func NewProxyBind(bind *bind.ICEBind) *ProxyBind { +type ProxyBind struct { + bind Bind + + // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address + wgRelayedEndpoint *bind.Endpoint + wgCurrentUsed *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool + + paused bool + pausedCond *sync.Cond + isStarted bool + + closeListener *listener.CloseListener + mtu uint16 +} + +func NewProxyBind(bind Bind, mtu uint16) *ProxyBind { p := &ProxyBind{ - Bind: bind, + bind: bind, closeListener: listener.NewCloseListener(), + pausedCond: sync.NewCond(&sync.Mutex{}), + mtu: mtu + bufsize.WGBufferOverhead, } return p @@ -46,25 +56,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind { // AddTurnConn adds a new connection to the bind. // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. +// +// Parameters: +// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages +// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address +// - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } - - p.fakeNetIP = fakeNetIP - p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} + p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return nil } + func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return &net.UDPAddr{ - IP: p.fakeNetIP.Addr().AsSlice(), - Port: int(p.fakeNetIP.Port()), - Zone: p.fakeNetIP.Addr().Zone(), - } + return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) } func (p *ProxyBind) SetDisconnectListener(disconnected func()) { @@ -76,17 +86,21 @@ func (p *ProxyBind) Work() { return } - p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) + p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgCurrentUsed = p.wgRelayedEndpoint // Start the proxy only once if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyBind) Pause() { @@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgCurrentUsed = addrToEndpoint(endpoint) + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} } func (p *ProxyBind) CloseConn() error { @@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error { } func (p *ProxyBind) close() error { + if p.remoteConn == nil { + return nil + } + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -120,7 +154,12 @@ func (p *ProxyBind) close() error { p.cancel() - p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, p.mtu) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -147,18 +186,13 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - msg := bind.RecvMessage{ - Endpoint: p.wgBindEndpoint, - Buffer: buf[:n], - } - p.Bind.RecvChan <- msg - p.pausedMu.Unlock() + p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n]) + p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index fcdc0189d..858143091 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -6,9 +6,7 @@ import ( "context" "fmt" "net" - "os" "sync" - "syscall" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -18,15 +16,20 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bufsize" + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( loopbackAddr = "127.0.0.1" ) +var ( + localHostNetIP = net.ParseIP("127.0.0.1") +) + // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int @@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = p.prepareSenderRawSocket() + p.rawConn, err = rawsocket.PrepareSenderRawSocket() if err != nil { return err } @@ -214,57 +217,17 @@ generatePort: return p.lastUsedPort, nil } -func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { - // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) - if err != nil { - return nil, fmt.Errorf("creating raw socket failed: %w", err) - } - - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) - } - - // Bind the socket to the "lo" interface. - err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") - if err != nil { - return nil, fmt.Errorf("binding to lo interface failed: %w", err) - } - - // Set the fwmark on the socket. - err = nbnet.SetSocketOpt(fd) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - - // Convert the file descriptor to a PacketConn. - file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) - if file == nil { - return nil, fmt.Errorf("converting fd to file failed") - } - packetConn, err := net.FilePacketConn(file) - if err != nil { - return nil, fmt.Errorf("converting file to packet conn failed: %w", err) - } - - return packetConn, nil -} - -func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { - localhost := net.ParseIP("127.0.0.1") - +func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { payload := gopacket.Payload(data) ipH := &layers.IPv4{ - DstIP: localhost, - SrcIP: localhost, + DstIP: localHostNetIP, + SrcIP: endpointAddr.IP, Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, } udpH := &layers.UDP{ - SrcPort: layers.UDPPort(port), + SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } @@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { if err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { + if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 3d71b01bd..ff44d30c0 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -18,41 +18,42 @@ import ( // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { - WgeBPFProxy *WGEBPFProxy + wgeBPFProxy *WGEBPFProxy remoteConn net.Conn ctx context.Context cancel context.CancelFunc - wgEndpointAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + wgEndpointCurrentUsedAddr *net.UDPAddr - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { +func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { return &ProxyWrapper{ - WgeBPFProxy: WgeBPFProxy, + wgeBPFProxy: proxy, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } } - func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { - addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) + addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) - p.wgEndpointAddr = addr + p.wgRelayedEndpointAddr = addr return err } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { - return p.wgEndpointAddr + return p.wgRelayedEndpointAddr } func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { @@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyWrapper) Pause() { @@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() { } log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgEndpointCurrentUsedAddr = endpoint + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map -func (e *ProxyWrapper) CloseConn() error { - if e.cancel == nil { +func (p *ProxyWrapper) CloseConn() error { + if p.cancel == nil { return fmt.Errorf("proxy not started") } - e.cancel() + p.cancel() - e.closeListener.SetCloseListener(nil) + p.closeListener.SetCloseListener(nil) - if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - return fmt.Errorf("close remote conn: %w", err) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to close remote conn: %w", err) } return nil } func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { - defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port)) - buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead) for { n, err := p.readFromRemote(ctx, buf) if err != nil { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) - p.pausedMu.Unlock() + err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { @@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } p.closeListener.Notify() if !errors.Is(err, io.EOF) { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err) } return 0, err } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 63bc2ed24..ad2807546 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy { } return ebpf.NewProxyWrapper(w.ebpfProxy) - } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go deleted file mode 100644 index 039f1cd3a..000000000 --- a/client/iface/wgproxy/factory_kernel_freebsd.go +++ /dev/null @@ -1,31 +0,0 @@ -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" -) - -// KernelFactory todo: check eBPF support on FreeBSD -type KernelFactory struct { - wgPort int - mtu uint16 -} - -func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { - log.Infof("WireGuard Proxy Factory will produce UDP proxy") - f := &KernelFactory{ - wgPort: wgPort, - mtu: mtu, - } - - return f -} - -func (w *KernelFactory) GetProxy() Proxy { - return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu) -} - -func (w *KernelFactory) Free() error { - return nil -} diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index 141b4c1f9..a1b1c34d7 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -3,24 +3,25 @@ package wgproxy import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" ) type USPFactory struct { - bind *bind.ICEBind + bind proxyBind.Bind + mtu uint16 } -func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { +func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory { log.Infof("WireGuard Proxy Factory will produce bind proxy") f := &USPFactory{ - bind: iceBind, + bind: bind, + mtu: mtu, } return f } func (w *USPFactory) GetProxy() Proxy { - return proxyBind.NewProxyBind(w.bind) + return proxyBind.NewProxyBind(w.bind, w.mtu) } func (w *USPFactory) Free() error { diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index c2879877e..3c8dfd30e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -11,6 +11,11 @@ type Proxy interface { EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + + //RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused + //and rewrite the src address to the endpoint address. + //With this logic can avoid the package loss from relayed connections. + RedirectAs(endpoint *net.UDPAddr) CloseConn() error SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 5add503e1..dd24d1cdc 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -3,54 +3,82 @@ package wgproxy import ( - "context" - "os" - "testing" + "fmt" + "net" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) -func TestProxyCloseByRemoteConnEBPF(t *testing.T) { - if os.Getenv("GITHUB_ACTIONS") != "true" { - t.Skip("Skipping test as it requires root privileges") - } - ctx := context.Background() +func seedProxies() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - - tests := []struct { - name string - proxy Proxy - }{ - { - name: "ebpf proxy", - proxy: &ebpf.ProxyWrapper{ - WgeBPFProxy: ebpfProxy, - }, - }, + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, } + pl = append(pl, pEbpf) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) - if err != nil { - t.Errorf("error: %v", err) - } - - _ = relayedConn.Close() - if err := tt.proxy.CloseConn(); err != nil { - t.Errorf("error: %v", err) - } - }) + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, } + pl = append(pl, pUDP) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + + ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) + if err := ebpfProxy.Listen(); err != nil { + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) + } + + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, + } + pl = append(pl, pEbpf) + + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, + } + pl = append(pl, pUDP) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind, 0), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + + return pl, nil } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go new file mode 100644 index 000000000..ad375ccde --- /dev/null +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -0,0 +1,39 @@ +//go:build !linux + +package wgproxy + +import ( + "net" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +func seedProxies() ([]proxyInstance, error) { + // todo extend with Bind proxy + pl := make([]proxyInstance, 0) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind, 0), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + return pl, nil +} diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 76e5ed6f7..1aeab66b7 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -1,5 +1,3 @@ -//go:build linux - package wgproxy import ( @@ -7,12 +5,9 @@ import ( "io" "net" "os" - "runtime" "testing" "time" - "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -22,6 +17,14 @@ func TestMain(m *testing.M) { os.Exit(code) } +type proxyInstance struct { + name string + proxy Proxy + wgPort int + endpointAddr *net.UDPAddr + closeFn func() error +} + type mocConn struct { closeChan chan struct{} closed bool @@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error { func TestProxyCloseByRemoteConn(t *testing.T) { ctx := context.Background() - tests := []struct { - name string - proxy Proxy - }{ - { - name: "userspace proxy", - proxy: udpProxy.NewWGUDPProxy(51830, 1280), - }, + tests, err := seedProxyForProxyCloseByRemoteConn() + if err != nil { + t.Fatalf("error: %v", err) } - if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { - ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) - if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) - } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) - - tests = append(tests, struct { - name string - proxy Proxy - }{ - name: "ebpf proxy", - proxy: proxyWrapper, - }) - } + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + defer func() { + _ = relayedConn.Close() + }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892") relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + err := tt.proxy.AddTurnConn(ctx, addr, relayedConn) if err != nil { t.Errorf("error: %v", err) } @@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }) } } + +// TestProxyRedirect todo extend the proxies with Bind proxy +func TestProxyRedirect(t *testing.T) { + tests, err := seedProxies() + if err != nil { + t.Fatalf("error: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr) + if err := tt.closeFn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} + +func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) { + t.Helper() + + msgHelloFromRelay := []byte("hello from relay") + msgRedirected := [][]byte{ + []byte("hello 1. to p2p"), + []byte("hello 2. to p2p"), + []byte("hello 3. to p2p"), + } + + dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: wgPort}) + if err != nil { + t.Fatalf("failed to listen on udp port: %s", err) + } + + relayedServer, _ := net.ListenUDP("udp", + &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + }, + ) + + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + + defer func() { + _ = dummyWgListener.Close() + _ = relayedConn.Close() + _ = relayedServer.Close() + }() + + if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil { + t.Errorf("error: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }() + + proxy.Work() + + if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil { + t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err) + } + + n, err := dummyWgListener.Read(make([]byte, 1024)) + if err != nil { + t.Errorf("error: %v", err) + } + + if n != len(msgHelloFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n) + } + + p2pEndpointAddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 56), + Port: 1234, + } + proxy.RedirectAs(p2pEndpointAddr) + + for _, msg := range msgRedirected { + if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil { + t.Errorf("error: %v", err) + } + } + + for i := 0; i < len(msgRedirected); i++ { + buf := make([]byte, 1024) + n, rAddr, err := dummyWgListener.ReadFrom(buf) + if err != nil { + t.Errorf("error: %v", err) + } + + if rAddr.String() != p2pEndpointAddr.String() { + t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String()) + } + if string(buf[:n]) != string(msgRedirected[i]) { + t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n])) + } + } +} diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go new file mode 100644 index 000000000..a11ac46d5 --- /dev/null +++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -0,0 +1,50 @@ +//go:build linux && !android + +package rawsocket + +import ( + "fmt" + "net" + "os" + "syscall" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func PrepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + if err != nil { + return nil, fmt.Errorf("creating raw socket failed: %w", err) + } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) + if err != nil { + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + + // Bind the socket to the "lo" interface. + err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") + if err != nil { + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = nbnet.SetSocketOpt(fd) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) + } + + return packetConn, nil +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index be65e2b27..4ef2f19c4 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,3 +1,5 @@ +//go:build linux && !android + package udp import ( @@ -21,16 +23,18 @@ type WGUDPProxy struct { localWGListenPort int mtu uint16 - remoteConn net.Conn - localConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + remoteConn net.Conn + localConn net.Conn + srcFakerConn *SrcFaker + sendPkg func(data []byte) (int, error) + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } @@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { p := &WGUDPProxy{ localWGListenPort: wgPort, mtu: mtu, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } return p @@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem p.ctx, p.cancel = context.WithCancel(ctx) p.localConn = localConn + p.sendPkg = p.localConn.Write p.remoteConn = remoteConn return err @@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + p.sendPkg = p.localConn.Write + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } if !p.isStarted { p.isStarted = true go p.proxyToRemote(p.ctx) go p.proxyToLocal(p.ctx) } + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // Pause pauses the proxy from receiving data from the remote peer @@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +// RedirectAs start to use the fake sourced raw socket as package sender +func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + defer func() { + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + }() + + p.paused = false + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } + srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create src faker conn: %s", err) + // fallback to continue without redirecting + p.paused = true + return + } + p.srcFakerConn = srcFakerConn + p.sendPkg = p.srcFakerConn.SendPkg } // CloseConn close the localConn @@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error { } func (p *WGUDPProxy) close() error { + var result *multierror.Error + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error { p.cancel() - var result *multierror.Error + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } @@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err)) + } + } + return cerrors.FormatErrorOrNil(result) } @@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - - _, err = p.localConn.Write(buf[:n]) - p.pausedMu.Unlock() + _, err = p.sendPkg(buf[:n]) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go new file mode 100644 index 000000000..fdc911463 --- /dev/null +++ b/client/iface/wgproxy/udp/rawsocket.go @@ -0,0 +1,101 @@ +//go:build linux && !android + +package udp + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" +) + +var ( + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + localHostNetIPAddr = &net.IPAddr{ + IP: net.ParseIP("127.0.0.1"), + } +) + +type SrcFaker struct { + srcAddr *net.UDPAddr + + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer +} + +func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { + rawSocket, err := rawsocket.PrepareSenderRawSocket() + if err != nil { + return nil, err + } + + ipH, udpH, err := prepareHeaders(dstPort, srcAddr) + if err != nil { + return nil, err + } + + f := &SrcFaker{ + srcAddr: srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + } + + return f, nil +} + +func (f *SrcFaker) Close() error { + return f.rawSocket.Close() +} + +func (f *SrcFaker) SendPkg(data []byte) (int, error) { + defer func() { + if err := f.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload) + if err != nil { + return 0, fmt.Errorf("serialize layers: %w", err) + } + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + if err != nil { + return 0, fmt.Errorf("write to raw conn: %w", err) + } + return n, nil +} + +func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { + ipH := &layers.IPv4{ + DstIP: net.ParseIP("127.0.0.1"), + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + udpH := &layers.UDP{ + SrcPort: layers.UDPPort(srcAddr.Port), + DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port + } + + err := udpH.SetNetworkLayerForChecksum(ipH) + if err != nil { + return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return ipH, udpH, nil +} diff --git a/client/internal/connect.go b/client/internal/connect.go index a6872ca0d..75c1aa75b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -35,7 +35,7 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) @@ -297,10 +297,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan state.Set(StatusConnected) if runningChan != nil { - select { - case runningChan <- struct{}{}: - default: - } + close(runningChan) + runningChan = nil } <-engineCtx.Done() diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index fdc2c3063..a14a01f40 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,15 +240,19 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { - policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) - if r.gpo { - policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) - } + localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) + gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) singleDomain := []string{domain} - if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) + if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) + } + + if r.gpo { + if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure gpo DNS policy: %w", err) + } } log.Debugf("added NRPT entry for domain: %s", domain) @@ -401,6 +405,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) } @@ -412,6 +417,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) } diff --git a/client/internal/dns/server_js.go b/client/internal/dns/server_js.go new file mode 100644 index 000000000..a8bc35d09 --- /dev/null +++ b/client/internal/dns/server_js.go @@ -0,0 +1,5 @@ +package dns + +func (s *DefaultServer) initialize() (hostManager, error) { + return &noopHostConfigurator{}, nil +} diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 89d637686..6ef0ab526 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type ServiceViaMemory struct { diff --git a/client/internal/dns/unclean_shutdown_js.go b/client/internal/dns/unclean_shutdown_js.go new file mode 100644 index 000000000..378ffc164 --- /dev/null +++ b/client/internal/dns/unclean_shutdown_js.go @@ -0,0 +1,19 @@ +package dns + +import ( + "context" +) + +type ShutdownState struct{} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} + +func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error { + return nil +} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 6b7dcc05e..def281f28 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type upstreamResolver struct { diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index bf2ee839b..5c7a3fbdd 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "sync" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -11,14 +12,18 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +var ( + // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also + listenPort uint16 = 5353 + listenPortMu sync.RWMutex ) const ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - ListenPort = 5353 - dnsTTL = 60 //seconds + dnsTTL = 60 //seconds ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. @@ -35,12 +40,20 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder + port uint16 } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { +func ListenPort() uint16 { + listenPortMu.RLock() + defer listenPortMu.RUnlock() + return listenPort +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + port: port, } } @@ -54,7 +67,13 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) + if m.port > 0 { + listenPortMu.Lock() + listenPort = m.port + listenPortMu.Unlock() + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -94,7 +113,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort}, + Values: []uint16{ListenPort()}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 3dafc76f3..4f2a72c6b 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -29,9 +29,9 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" @@ -168,7 +168,7 @@ type Engine struct { wgInterface WGIface - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -203,6 +203,13 @@ type Engine struct { // auto-update updateManager *updatemanager.UpdateManager + + // WireGuard interface monitor + wgIfaceMonitor *WGIfaceMonitor + wgIfaceMonitorWg sync.WaitGroup + + // dns forwarder port + dnsFwdPort uint16 } // Peer is an instance of the Connection Peer @@ -245,6 +252,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + dnsFwdPort: dnsfwd.ListenPort(), } sm := profilemanager.NewServiceManager("") @@ -350,6 +358,9 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } + // Stop WireGuard interface monitor and wait for it to exit + e.wgIfaceMonitorWg.Wait() + return nil } @@ -466,14 +477,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("initialize dns server: %w", err) } - iceCfg := icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - } + iceCfg := e.createICEConfig() e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface) e.connMgr.Start(e.ctx) @@ -486,6 +490,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // starting network monitor at the very last to avoid disruptions e.startNetworkMonitor() + + // monitor WireGuard interface lifecycle and restart engine on changes + e.wgIfaceMonitor = NewWGIfaceMonitor() + e.wgIfaceMonitorWg.Add(1) + + go func() { + defer e.wgIfaceMonitorWg.Done() + + if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { + log.Infof("WireGuard interface monitor: %s, restarting engine", err) + e.restartEngine() + } else if err != nil { + log.Warnf("WireGuard interface monitor: %s", err) + } + }() + return nil } @@ -977,7 +997,6 @@ func (e *Engine) receiveManagementEvents() { e.config.LazyConnectionEnabled, ) - // err = e.mgmClient.Sync(info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) if err != nil { // happens if management is unavailable for a long time. @@ -988,7 +1007,7 @@ func (e *Engine) receiveManagementEvents() { } log.Debugf("stopped receiving updates from Management Service") }() - log.Debugf("connecting to Management Service updates stream") + log.Infof("connecting to Management Service updates stream") } func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { @@ -1093,7 +1112,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1351,14 +1370,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV Addr: e.getRosenpassAddr(), PermissiveMode: e.config.RosenpassPermissive, }, - ICEConfig: icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - }, + ICEConfig: e.createICEConfig(), } serviceDependencies := peer.ServiceDependencies{ @@ -1859,6 +1871,7 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, + forwarderPort uint16, ) { if e.config.DisableServerRoutes { return @@ -1875,16 +1888,20 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - if e.dnsForwardMgr == nil { - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - + switch { + case e.dnsForwardMgr == nil: + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - } else { + case e.dnsFwdPort != forwarderPort: + log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) + e.restartDnsFwd(fwdEntries, forwarderPort) + e.dnsFwdPort = forwarderPort + + default: e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1894,6 +1911,20 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } + +} + +func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { + log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) + // stop and start the forwarder to apply the new port + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/engine_generic.go b/client/internal/engine_generic.go new file mode 100644 index 000000000..34a75e45b --- /dev/null +++ b/client/internal/engine_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for non-WASM environments +func (e *Engine) createICEConfig() icemaker.Config { + return icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.SingleSocketUDPMux, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } +} diff --git a/client/internal/engine_js.go b/client/internal/engine_js.go new file mode 100644 index 000000000..dce3c57fb --- /dev/null +++ b/client/internal/engine_js.go @@ -0,0 +1,18 @@ +//go:build js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for WASM environment. +func (e *Engine) createICEConfig() icemaker.Config { + cfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + return cfg +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 90c8cbc60..344104405 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -26,10 +26,15 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" @@ -41,10 +46,8 @@ import ( "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -84,7 +87,7 @@ type MockWGIface struct { NameFunc func() string AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface - UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error @@ -134,7 +137,7 @@ func (m *MockWGIface) ToInterface() *net.Interface { return m.ToInterfaceFunc() } -func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { return m.UpFunc() } @@ -413,7 +416,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) + engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) engine.ctx = ctx engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) @@ -1583,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index bf96153ea..690fdb7cc 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -9,9 +9,9 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -24,7 +24,7 @@ type wgIfaceBase interface { Name() string Address() wgaddr.Address ToInterface() *net.Interface - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index dbb4747a5..a4ffa3a25 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -14,7 +14,7 @@ import ( "github.com/ti-mo/netfilter" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const defaultChannelSize = 100 diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index e28fdf2f4..899faf108 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { return false } diff --git a/client/internal/networkmonitor/check_change_js.go b/client/internal/networkmonitor/check_change_js.go new file mode 100644 index 000000000..640cf7184 --- /dev/null +++ b/client/internal/networkmonitor/check_change_js.go @@ -0,0 +1,12 @@ +package networkmonitor + +import ( + "context" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + // No-op for WASM - network changes don't apply + return nil +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 224a8144c..8db9e58f4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -6,7 +6,6 @@ import ( "math/rand" "net" "net/netip" - "os" "runtime" "sync" "time" @@ -29,10 +28,6 @@ import ( semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) -const ( - defaultWgKeepAlive = 25 * time.Second -) - type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler @@ -118,6 +113,8 @@ type Conn struct { // debug purpose dumpState *stateDump + + endpointUpdater *EndpointUpdater } // NewConn creates a new not opened Conn to the remote peer. @@ -130,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), } return conn, nil @@ -174,7 +172,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) - if os.Getenv("NB_FORCE_RELAY") != "true" { + if !isForceRelayed() { conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } @@ -250,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgProxyICE = nil } - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.RemoveWgPeer(); err != nil { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } @@ -376,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn wgProxy.Work() } - if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { + conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) + presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) + if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() + if conn.wgProxyRelay != nil { + conn.Log.Debugf("redirect packets from relayed conn to WireGuard") + conn.wgProxyRelay.RedirectAs(ep) + } + conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -410,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { + presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil { conn.Log.Errorf("failed to switch to relay conn: %v", err) } @@ -419,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() { defer conn.wgWatcherWg.Done() conn.workerRelay.EnableWgWatcher(conn.ctx) }() + conn.wgProxyRelay.Work() conn.currentConnPriority = conntype.Relay } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) @@ -478,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { + presharedKey := conn.presharedKey(rci.rosenpassPubKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) } @@ -546,17 +554,6 @@ func (conn *Conn) onGuardEvent() { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { - presharedKey := conn.presharedKey(remoteRPKey) - return conn.config.WgConfig.WgInterface.UpdatePeer( - conn.config.WgConfig.RemoteKey, - conn.config.WgConfig.AllowedIps, - defaultWgKeepAlive, - addr, - presharedKey, - ) -} - func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -699,10 +696,6 @@ func (conn *Conn) isICEActive() bool { return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected } -func (conn *Conn) removeWgPeer() error { - return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.Log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go new file mode 100644 index 000000000..39cb95591 --- /dev/null +++ b/client/internal/peer/endpoint.go @@ -0,0 +1,105 @@ +package peer + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const ( + defaultWgKeepAlive = 25 * time.Second + fallbackDelay = 5 * time.Second +) + +type EndpointUpdater struct { + log *logrus.Entry + wgConfig WgConfig + initiator bool + + // mu protects updateWireGuardPeer and cancelFunc + mu sync.Mutex + cancelFunc func() + updateWg sync.WaitGroup +} + +func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater { + return &EndpointUpdater{ + log: log, + wgConfig: wgConfig, + initiator: initiator, + } +} + +// ConfigureWGEndpoint sets up the WireGuard endpoint configuration. +// The initiator immediately configures the endpoint, while the non-initiator +// waits for a fallback period before configuring to avoid handshake congestion. +func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.initiator { + e.log.Debugf("configure up WireGuard as initiatr") + return e.updateWireGuardPeer(addr, presharedKey) + } + + // prevent to run new update while cancel the previous update + e.waitForCloseTheDelayedUpdate() + + var ctx context.Context + ctx, e.cancelFunc = context.WithCancel(context.Background()) + e.updateWg.Add(1) + go e.scheduleDelayedUpdate(ctx, addr, presharedKey) + + e.log.Debugf("configure up WireGuard and wait for handshake") + return e.updateWireGuardPeer(nil, presharedKey) +} + +func (e *EndpointUpdater) RemoveWgPeer() error { + e.mu.Lock() + defer e.mu.Unlock() + + e.waitForCloseTheDelayedUpdate() + return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) +} + +func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { + if e.cancelFunc == nil { + return + } + + e.cancelFunc() + e.cancelFunc = nil + e.updateWg.Wait() +} + +// scheduleDelayedUpdate waits for the fallback period before updating the endpoint +func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) { + defer e.updateWg.Done() + t := time.NewTimer(fallbackDelay) + defer t.Stop() + + select { + case <-ctx.Done(): + return + case <-t.C: + e.mu.Lock() + if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { + e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) + } + e.mu.Unlock() + } +} + +func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error { + return e.wgConfig.WgInterface.UpdatePeer( + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + endpoint, + presharedKey, + ) +} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go new file mode 100644 index 000000000..32a458d00 --- /dev/null +++ b/client/internal/peer/env.go @@ -0,0 +1,14 @@ +package peer + +import ( + "os" + "strings" +) + +const ( + EnvKeyNBForceRelay = "NB_FORCE_RELAY" +) + +func isForceRelayed() bool { + return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 70850e6eb..09cf9ae63 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -3,6 +3,8 @@ package guard import ( "context" "fmt" + "slices" + "sort" "sync" "time" @@ -24,8 +26,8 @@ type ICEMonitor struct { iFaceDiscover stdnet.ExternalIFaceDiscover iceConfig icemaker.Config - currentCandidates []ice.Candidate - candidatesMu sync.Mutex + currentCandidatesAddress []string + candidatesMu sync.Mutex } func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { @@ -115,16 +117,21 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool { cm.candidatesMu.Lock() defer cm.candidatesMu.Unlock() - if len(cm.currentCandidates) != len(newCandidates) { - cm.currentCandidates = newCandidates + newAddresses := make([]string, len(newCandidates)) + for i, c := range newCandidates { + newAddresses[i] = c.Address() + } + sort.Strings(newAddresses) + + if len(cm.currentCandidatesAddress) != len(newAddresses) { + cm.currentCandidatesAddress = newAddresses return true } - for i, candidate := range cm.currentCandidates { - if candidate.Address() != newCandidates[i].Address() { - cm.currentCandidates = newCandidates - return true - } + // Compare elements + if !slices.Equal(cm.currentCandidatesAddress, newAddresses) { + cm.currentCandidatesAddress = newAddresses + return true } return false diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 218872c15..0ed200fda 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -30,9 +30,10 @@ type WGWatcher struct { peerKey string stateDump *stateDump - ctx context.Context - ctxCancel context.CancelFunc - ctxLock sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex + enabledTime time.Time } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() + w.enabledTime = time.Now() if w.ctx != nil && w.ctx.Err() == nil { w.log.Errorf("WireGuard watcher already enabled") @@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex onDisconnectedFn() return } + if lastHandshake.IsZero() { + elapsed := handshake.Sub(w.enabledTime).Seconds() + w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) + } + lastHandshake = *handshake resetTime := time.Until(handshake.Add(checkPeriod)) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 896c55b6c..eb886a4d3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -9,11 +9,10 @@ import ( "time" "github.com/pion/ice/v4" - "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -55,10 +54,6 @@ type WorkerICE struct { sessionID ICESessionID muxAgent sync.Mutex - StunTurn []*stun.URI - - sentExtraSrflx bool - localUfrag string localPwd string @@ -139,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.muxAgent.Unlock() return } - w.sentExtraSrflx = false w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true @@ -166,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA w.log.Errorf("error while handling remote candidate") return } + + if shouldAddExtraCandidate(candidate) { + // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) + // this is useful when network has an existing port forwarding rule for the wireguard port and this peer + extraSrflx, err := extraSrflxCandidate(candidate) + if err != nil { + w.log.Errorf("failed creating extra server reflexive candidate %s", err) + return + } + + if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil { + w.log.Errorf("error while handling remote candidate") + return + } + } } func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { @@ -209,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates [] return nil, err } - if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil { + if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) { + w.onICESelectedCandidatePair(agent, c1, c2) + }); err != nil { return nil, err } @@ -327,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) return } - mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) + mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault) if !ok { w.log.Warn("invalid udp mux conversion") return @@ -354,48 +365,19 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() - - if !w.shouldSendExtraSrflxCandidate(candidate) { - return - } - - // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) - // this is useful when network has an existing port forwarding rule for the wireguard port and this peer - extraSrflx, err := extraSrflxCandidate(candidate) - if err != nil { - w.log.Errorf("failed creating extra server reflexive candidate %s", err) - return - } - w.sentExtraSrflx = true - - go func() { - err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key) - if err != nil { - w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err) - } - }() } -func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { +func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.config.Key) - w.muxAgent.Lock() - - pair, err := w.agent.GetSelectedCandidatePair() - if err != nil { - w.log.Warnf("failed to get selected candidate pair: %s", err) - w.muxAgent.Unlock() + pairStat, ok := agent.GetSelectedCandidatePairStats() + if !ok { + w.log.Warnf("failed to get selected candidate pair stats") return } - if pair == nil { - w.log.Warnf("selected candidate pair is nil, cannot proceed") - w.muxAgent.Unlock() - return - } - w.muxAgent.Unlock() - duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second)) + duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second)) if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { w.log.Debugf("failed to update latency for peer: %s", err) return @@ -424,22 +406,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia } } -func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { - if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { - return true - } - return false -} - func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { - isControlling := w.config.LocalKey > w.config.Key - if isControlling { - return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + if isController(w.config) { + return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } } +func shouldAddExtraCandidate(candidate ice.Candidate) bool { + if candidate.Type() != ice.CandidateTypeServerReflexive { + return false + } + + if candidate.Port() == candidate.RelatedAddress().Port { + return false + } + + // in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates + // in newer version we generate locally the extra candidate + if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok { + return false + } + return true +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ @@ -455,6 +446,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive } for _, e := range candidate.Extensions() { + // overwrite the original candidate ID with the new one to avoid candidate duplication + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = candidate.ID() + } if err := ec.AddExtension(e); err != nil { return nil, err } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 8c3d5a571..fa208716f 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ProbeResult holds the info about the result of a relay probe request diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 9069cdcc5..47c2ffcda 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -24,8 +24,8 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) const dnsTimeout = 8 * time.Second @@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index a6775c45a..04513bbe4 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,9 +36,9 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" - nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager { notifier := notifier.NewNotifier() sysOps := systemops.NewSysOps(config.WGInterface, notifier) + if runtime.GOOS == "windows" && config.WGInterface != nil { + nbnet.SetVPNInterfaceName(config.WGInterface.Name()) + } + dm := &DefaultManager{ ctx: mCTX, stop: cancel, @@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error { return nil } - if err := m.sysOps.CleanupRouting(nil); err != nil { + if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error { ips := resolveURLsToIPs(initialAddresses) - if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { + if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil { return fmt.Errorf("setup routing: %w", err) } @@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { - if err := m.sysOps.CleanupRouting(stateManager); err != nil { + if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") } + + if runtime.GOOS == "windows" { + nbnet.SetVPNInterfaceName("") + } } m.mux.Lock() diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index a375ce832..7cb8dae93 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -12,11 +12,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 128afa2a5..26a548634 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -3,7 +3,6 @@ package systemops import ( - "context" "errors" "fmt" "net" @@ -22,7 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net/hooks" ) const localSubnetsCacheTTL = 15 * time.Minute @@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { return nil } - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() + hooks.RemoveWriteHooks() + hooks.RemoveCloseHooks() + hooks.RemoveAddressRemoveHooks() if err := r.refCounter.Flush(); err != nil { return fmt.Errorf("flush route manager: %w", err) @@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := util.GetPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - + beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } @@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } - afterHook := func(connID nbnet.ConnectionID) error { + afterHook := func(connID hooks.ConnectionID) error { if err := r.refCounter.DecrementWithID(string(connID)); err != nil { return fmt.Errorf("remove route reference: %w", err) } @@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M var merr *multierror.Error for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err)) + continue + } + if err := beforeHook("init", prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err)) } } - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } + hooks.AddWriteHook(beforeHook) + hooks.AddCloseHook(afterHook) - var merr *multierror.Error - for _, ip := range resolvedIPs { - merr = multierror.Append(merr, beforeHook(connID, ip.IP)) - } - return nberrors.FormatErrorOrNil(merr) - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { + hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.Decrement(prefix); err != nil { return fmt.Errorf("remove route reference: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index c1c1182bc..32ea38a7a 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" ) type dialer interface { @@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) intf, err := net.InterfaceByName(wgInterface.Name()) @@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 10356eae0..99a363371 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -12,14 +12,14 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_js.go b/client/internal/routemanager/systemops/systemops_js.go new file mode 100644 index 000000000..808507fc9 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_js.go @@ -0,0 +1,48 @@ +package systemops + +import ( + "errors" + "net" + "net/netip" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +var ErrRouteNotSupported = errors.New("route operations not supported on js") + +func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func GetRoutesFromTable() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +func hasSeparateRouting() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +// GetDetailedRoutesFromTable returns empty routes for WASM. +func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { + return []DetailedRoute{}, nil +} + +func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error { + return nil +} + +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error { + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index c0cef94ba..bd10f131f 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // IPRule contains IP rule information for debugging @@ -94,15 +94,15 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { - if !nbnet.AdvancedRouting() { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) { + if !advancedRouting { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) } defer func() { if err != nil { - if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { - if !nbnet.AdvancedRouting() { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if !advancedRouting { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 83b64e82b..905a7bc12 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -1,4 +1,4 @@ -//go:build !linux && !ios +//go:build !linux && !ios && !js package systemops diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index f165f7779..d43c2d5bf 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -20,11 +20,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index ad37f611f..959c697e4 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type PacketExpectation struct { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 4f836897b..7bce6af80 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -8,6 +8,7 @@ import ( "net/netip" "os" "runtime/debug" + "sort" "strconv" "sync" "syscall" @@ -19,9 +20,16 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" ) -const InfiniteLifetime = 0xffffffff +func init() { + nbnet.GetBestInterfaceFunc = GetBestInterface +} + +const ( + InfiniteLifetime = 0xffffffff +) type RouteUpdateType int @@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct { Table [1]MIB_IPFORWARD_ROW2 // Flexible array member } +// candidateRoute represents a potential route for selection during route lookup +type candidateRoute struct { + interfaceIndex uint32 + prefixLength uint8 + routeMetric uint32 + interfaceMetric int +} + // IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix type IP_ADDRESS_PREFIX struct { Prefix SOCKADDR_INET @@ -177,11 +193,20 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + return r.cleanupRefCounter(stateManager) } @@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) { func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) { if table != nil { - ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) - if ret != 0 { - log.Warnf("FreeMibTable failed with return code: %d", ret) - } + _, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) } } @@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute { entryPtr := basePtr + uintptr(i)*entrySize entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) - detailed := buildWindowsDetailedRoute(entry) - if detailed != nil { + if detailed := buildWindowsDetailedRoute(entry); detailed != nil { detailedRoutes = append(detailedRoutes, *detailed) } } @@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { return ip } +// parseCandidatesFromTable extracts all matching candidate routes from the routing table +func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute { + var candidates []candidateRoute + entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{}) + basePtr := uintptr(unsafe.Pointer(&table.Table[0])) + + for i := uint32(0); i < table.NumEntries; i++ { + entryPtr := basePtr + uintptr(i)*entrySize + entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) + + if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil { + candidates = append(candidates, *candidate) + } + } + + return candidates +} + +// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry +// Returns nil if the route doesn't match the destination or should be skipped +func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute { + if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex { + return nil + } + + destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex)) + if !destPrefix.IsValid() || !destPrefix.Contains(dest) { + return nil + } + + interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family) + + return &candidateRoute{ + interfaceIndex: entry.InterfaceIndex, + prefixLength: entry.DestinationPrefix.PrefixLength, + routeMetric: entry.Metric, + interfaceMetric: interfaceMetric, + } +} + // getInterfaceMetric retrieves the interface metric for a given interface and address family func getInterfaceMetric(interfaceIndex uint32, family int16) int { if interfaceIndex == 0 { @@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int { return int(ipInterfaceRow.Metric) } +// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric +func sortRouteCandidates(candidates []candidateRoute) { + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].prefixLength != candidates[j].prefixLength { + return candidates[i].prefixLength > candidates[j].prefixLength + } + if candidates[i].routeMetric != candidates[j].routeMetric { + return candidates[i].routeMetric < candidates[j].routeMetric + } + return candidates[i].interfaceMetric < candidates[j].interfaceMetric + }) +} + +// GetBestInterface finds the best interface for reaching a destination, +// excluding the VPN interface to avoid routing loops. +// +// Route selection priority: +// 1. Longest prefix match (most specific route) +// 2. Lowest route metric +// 3. Lowest interface metric +func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { + var skipInterfaceIndex int + if vpnIntf != "" { + if iface, err := net.InterfaceByName(vpnIntf); err == nil { + skipInterfaceIndex = iface.Index + } else { + // not critical, if we cannot get ahold of the interface then we won't need to skip it + log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err) + } + } + + table, err := getWindowsRoutingTable() + if err != nil { + return nil, fmt.Errorf("get routing table: %w", err) + } + defer freeWindowsRoutingTable(table) + + candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex) + + if len(candidates) == 0 { + return nil, fmt.Errorf("no route to %s", dest) + } + + // Sort routes: prefix length -> route metric -> interface metric + sortRouteCandidates(candidates) + + for _, candidate := range candidates { + iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex)) + if err != nil { + log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err) + continue + } + + if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() { + continue + } + + if iface.Flags&net.FlagUp == 0 { + log.Debugf("interface %s is down, trying next route", iface.Name) + continue + } + + log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d", + dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric) + return iface, nil + } + + return nil, fmt.Errorf("no usable interface found for %s", dest) +} + // formatRouteAge formats the route age in seconds to a human-readable string func formatRouteAge(ageSeconds uint32) string { if ageSeconds == 0 { diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 523bd0b0d..3561adec4 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) var ( diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go index ac5a48e37..57ea32f69 100644 --- a/client/internal/routemanager/util/ip.go +++ b/client/internal/routemanager/util/ip.go @@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) { if !ok { return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip) } + addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) + prefix := netip.PrefixFrom(addr, addr.BitLen()) return prefix, nil } diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go index e80adb42b..8961eaa69 100644 --- a/client/internal/stdnet/dialer.go +++ b/client/internal/stdnet/dialer.go @@ -5,7 +5,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // Dial connects to the address on the named network. diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go index 9ce0a5556..d3be1896f 100644 --- a/client/internal/stdnet/listener.go +++ b/client/internal/stdnet/listener.go @@ -6,7 +6,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ListenPacket listens for incoming packets on the given network and address. diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go new file mode 100644 index 000000000..78d70c15b --- /dev/null +++ b/client/internal/wg_iface_monitor.go @@ -0,0 +1,98 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net" + "runtime" + "time" + + log "github.com/sirupsen/logrus" +) + +// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine +// if the interface is deleted externally while the engine is running. +type WGIfaceMonitor struct { + done chan struct{} +} + +// NewWGIfaceMonitor creates a new WGIfaceMonitor instance. +func NewWGIfaceMonitor() *WGIfaceMonitor { + return &WGIfaceMonitor{ + done: make(chan struct{}), + } +} + +// Start begins monitoring the WireGuard interface. +// It relies on the provided context cancellation to stop. +func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { + defer close(m.done) + + // Skip on mobile platforms as they handle interface lifecycle differently + if runtime.GOOS == "android" || runtime.GOOS == "ios" { + log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS) + return false, errors.New("not supported on mobile platforms") + } + + if ifaceName == "" { + log.Debugf("Interface monitor: empty interface name, skipping monitor") + return false, errors.New("empty interface name") + } + + // Get initial interface index to track the specific interface instance + expectedIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName) + return false, fmt.Errorf("interface %s not found: %w", ifaceName, err) + } + + log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } + +} + +// getInterfaceIndex returns the index of a network interface by name. +// Returns an error if the interface is not found. +func getInterfaceIndex(name string) (int, error) { + if name == "" { + return 0, fmt.Errorf("empty interface name") + } + ifi, err := net.InterfaceByName(name) + if err != nil { + // Check if it's specifically a "not found" error + if errors.Is(err, &net.OpError{}) { + // On some systems, this might be a "not found" error + return 0, fmt.Errorf("interface not found: %w", err) + } + return 0, fmt.Errorf("failed to lookup interface: %w", err) + } + if ifi == nil { + return 0, fmt.Errorf("interface not found") + } + return ifi.Index, nil +} diff --git a/client/net/conn.go b/client/net/conn.go new file mode 100644 index 000000000..918e7f628 --- /dev/null +++ b/client/net/conn.go @@ -0,0 +1,49 @@ +//go:build !ios + +package net + +import ( + "io" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/net/hooks" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID hooks.ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +func (c *Conn) Close() error { + return closeConn(c.ID, c.Conn) +} + +// TCPConn wraps net.TCPConn to override its Close method to include hook functionality. +type TCPConn struct { + *net.TCPConn + ID hooks.ConnectionID +} + +// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +func (c *TCPConn) Close() error { + return closeConn(c.ID, c.TCPConn) +} + +// closeConn is a helper function to close connections and execute close hooks. +func closeConn(id hooks.ConnectionID, conn io.Closer) error { + err := conn.Close() + + closeHooks := hooks.GetCloseHooks() + for _, hook := range closeHooks { + if err := hook(id); err != nil { + log.Errorf("Error executing close hook: %v", err) + } + } + + return err +} diff --git a/client/net/dial.go b/client/net/dial.go new file mode 100644 index 000000000..041a00e5d --- /dev/null +++ b/client/net/dial.go @@ -0,0 +1,82 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + udpConn, ok := c.Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.TCPConn: + // Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + tcpConn, ok := c.Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn) + } + return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/dial_ios.go b/client/net/dial_ios.go similarity index 100% rename from util/net/dial_ios.go rename to client/net/dial_ios.go diff --git a/util/net/dialer.go b/client/net/dialer.go similarity index 99% rename from util/net/dialer.go rename to client/net/dialer.go index 0786c667e..29bec05a7 100644 --- a/util/net/dialer.go +++ b/client/net/dialer.go @@ -16,6 +16,5 @@ func NewDialer() *Dialer { Dialer: &net.Dialer{}, } dialer.init() - return dialer } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go new file mode 100644 index 000000000..2e1eb53d8 --- /dev/null +++ b/client/net/dialer_dial.go @@ -0,0 +1,87 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + log.Debugf("Dialing %s %s", network, address) + + if CustomRoutingDisabled() || AdvancedRouting() { + return d.Dialer.DialContext(ctx, network, address) + } + + connID := hooks.GenerateConnID() + if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error { + if ctx.Err() != nil { + return ctx.Err() + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + + resolver := customResolver + if resolver == nil { + resolver = net.DefaultResolver + } + + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var merr *multierror.Error + for _, ip := range ips { + prefix, err := util.GetPrefixFromIP(ip.IP) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err)) + continue + } + for _, hook := range writeHooks { + if err := hook(connID, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err)) + } + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/dialer_init_android.go b/client/net/dialer_init_android.go similarity index 100% rename from util/net/dialer_init_android.go rename to client/net/dialer_init_android.go diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go new file mode 100644 index 000000000..18ebc6ad1 --- /dev/null +++ b/client/net/dialer_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (d *Dialer) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/dialer_init_linux.go b/client/net/dialer_init_linux.go similarity index 100% rename from util/net/dialer_init_linux.go rename to client/net/dialer_init_linux.go diff --git a/client/net/dialer_init_windows.go b/client/net/dialer_init_windows.go new file mode 100644 index 000000000..6eefe5b1e --- /dev/null +++ b/client/net/dialer_init_windows.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyUnicastIFToSocket +} diff --git a/util/net/env.go b/client/net/env.go similarity index 94% rename from util/net/env.go rename to client/net/env.go index 32425665d..8f326ca88 100644 --- a/util/net/env.go +++ b/client/net/env.go @@ -11,6 +11,7 @@ import ( const ( envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" ) // CustomRoutingDisabled returns true if custom routing is disabled. diff --git a/client/net/env_android.go b/client/net/env_android.go new file mode 100644 index 000000000..9d89951a1 --- /dev/null +++ b/client/net/env_android.go @@ -0,0 +1,24 @@ +//go:build android + +package net + +// Init initializes the network environment for Android +func Init() { + // No initialization needed on Android +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on Android since we cannot handle routes dynamically. +func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on Android +func SetVPNInterfaceName(name string) { + // No-op on Android - not needed for Android VPN service +} + +// GetVPNInterfaceName returns empty string on Android +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_generic.go b/client/net/env_generic.go new file mode 100644 index 000000000..f467930c3 --- /dev/null +++ b/client/net/env_generic.go @@ -0,0 +1,23 @@ +//go:build !linux && !windows && !android + +package net + +// Init initializes the network environment (no-op on non-Linux/Windows platforms) +func Init() { + // No-op on non-Linux/Windows platforms +} + +// AdvancedRouting returns false on non-Linux/Windows platforms +func AdvancedRouting() bool { + return false +} + +// SetVPNInterfaceName is a no-op on non-Windows platforms +func SetVPNInterfaceName(name string) { + // No-op on non-Windows platforms +} + +// GetVPNInterfaceName returns empty string on non-Windows platforms +func GetVPNInterfaceName() string { + return "" +} diff --git a/util/net/env_linux.go b/client/net/env_linux.go similarity index 86% rename from util/net/env_linux.go rename to client/net/env_linux.go index 3159f6462..82d9a74a8 100644 --- a/util/net/env_linux.go +++ b/client/net/env_linux.go @@ -17,8 +17,7 @@ import ( const ( // these have the same effect, skip socket env supported for backward compatibility - envSkipSocketMark = "NB_SKIP_SOCKET_MARK" - envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" ) var advancedRoutingSupported bool @@ -27,6 +26,7 @@ func Init() { advancedRoutingSupported = checkAdvancedRoutingSupport() } +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes func AdvancedRouting() bool { return advancedRoutingSupported } @@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool { } func CheckFwmarkSupport() bool { - // temporarily enable advanced routing to check fwmarks are supported + // temporarily enable advanced routing to check if fwmarks are supported old := advancedRoutingSupported advancedRoutingSupported = true defer func() { @@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool { } return true } + +// SetVPNInterfaceName is a no-op on Linux +func SetVPNInterfaceName(name string) { + // No-op on Linux - not needed for fwmark-based routing +} + +// GetVPNInterfaceName returns empty string on Linux +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_windows.go b/client/net/env_windows.go new file mode 100644 index 000000000..7e8868ba5 --- /dev/null +++ b/client/net/env_windows.go @@ -0,0 +1,67 @@ +//go:build windows + +package net + +import ( + "os" + "strconv" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +var ( + vpnInterfaceName string + vpnInitMutex sync.RWMutex + + advancedRoutingSupported bool +) + +func Init() { + advancedRoutingSupported = checkAdvancedRoutingSupport() +} + +func checkAdvancedRoutingSupport() bool { + var err error + var legacyRouting bool + if val := os.Getenv(envUseLegacyRouting); val != "" { + legacyRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + } + } + + if legacyRouting || netstack.IsEnabled() { + log.Info("advanced routing has been requested to be disabled") + return false + } + + log.Info("system supports advanced routing") + + return true +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes +func AdvancedRouting() bool { + return advancedRoutingSupported +} + +// GetVPNInterfaceName returns the stored VPN interface name +func GetVPNInterfaceName() string { + vpnInitMutex.RLock() + defer vpnInitMutex.RUnlock() + return vpnInterfaceName +} + +// SetVPNInterfaceName sets the VPN interface name for lazy initialization +func SetVPNInterfaceName(name string) { + vpnInitMutex.Lock() + defer vpnInitMutex.Unlock() + vpnInterfaceName = name + + if name != "" { + log.Infof("VPN interface name set to %s for route exclusion", name) + } +} diff --git a/client/net/hooks/hooks.go b/client/net/hooks/hooks.go new file mode 100644 index 000000000..93d8e18ef --- /dev/null +++ b/client/net/hooks/hooks.go @@ -0,0 +1,93 @@ +package hooks + +import ( + "net/netip" + "slices" + "sync" + + "github.com/google/uuid" +) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} + +type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error +type CloseHookFunc func(connID ConnectionID) error +type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + +var ( + hooksMutex sync.RWMutex + + writeHooks []WriteHookFunc + closeHooks []CloseHookFunc + addressRemoveHooks []AddressRemoveHookFunc +) + +// AddWriteHook allows adding a new hook to be executed before writing/dialing. +func AddWriteHook(hook WriteHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = append(writeHooks, hook) +} + +// AddCloseHook allows adding a new hook to be executed on connection close. +func AddCloseHook(hook CloseHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = append(closeHooks, hook) +} + +// RemoveWriteHooks removes all write hooks. +func RemoveWriteHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = nil +} + +// RemoveCloseHooks removes all close hooks. +func RemoveCloseHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = nil +} + +// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed. +func AddAddressRemoveHook(hook AddressRemoveHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = append(addressRemoveHooks, hook) +} + +// RemoveAddressRemoveHooks removes all listener address hooks. +func RemoveAddressRemoveHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = nil +} + +// GetWriteHooks returns a copy of the current write hooks. +func GetWriteHooks() []WriteHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(writeHooks) +} + +// GetCloseHooks returns a copy of the current close hooks. +func GetCloseHooks() []CloseHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(closeHooks) +} + +// GetAddressRemoveHooks returns a copy of the current listener address remove hooks. +func GetAddressRemoveHooks() []AddressRemoveHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(addressRemoveHooks) +} diff --git a/client/net/listen.go b/client/net/listen.go new file mode 100644 index 000000000..da7262806 --- /dev/null +++ b/client/net/listen.go @@ -0,0 +1,47 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *PacketConn: + // Legacy routing: wrapped connection for hooks + udpConn, ok := c.PacketConn.(*net.UDPConn) + if !ok { + if err := c.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/listen_ios.go b/client/net/listen_ios.go similarity index 100% rename from util/net/listen_ios.go rename to client/net/listen_ios.go diff --git a/util/net/listener.go b/client/net/listener.go similarity index 81% rename from util/net/listener.go rename to client/net/listener.go index f4d769f58..4c2f53c05 100644 --- a/util/net/listener.go +++ b/client/net/listener.go @@ -7,14 +7,12 @@ import ( // ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before // responding via the socket and after closing. This can be used to bypass the VPN for listeners. type ListenerConfig struct { - *net.ListenConfig + net.ListenConfig } // NewListener creates a new ListenerConfig instance. func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } + listener := &ListenerConfig{} listener.init() return listener diff --git a/util/net/listener_init_android.go b/client/net/listener_init_android.go similarity index 100% rename from util/net/listener_init_android.go rename to client/net/listener_init_android.go diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go new file mode 100644 index 000000000..4f8f17ab2 --- /dev/null +++ b/client/net/listener_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (l *ListenerConfig) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/listener_init_linux.go b/client/net/listener_init_linux.go similarity index 100% rename from util/net/listener_init_linux.go rename to client/net/listener_init_linux.go diff --git a/client/net/listener_init_windows.go b/client/net/listener_init_windows.go new file mode 100644 index 000000000..a9399b5f1 --- /dev/null +++ b/client/net/listener_init_windows.go @@ -0,0 +1,8 @@ +package net + +func (l *ListenerConfig) init() { + // TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses. + // For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case + // the interface will be selected that serves the default route. + l.ListenConfig.Control = applyUnicastIFToSocket +} diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go new file mode 100644 index 000000000..0bb5ad67d --- /dev/null +++ b/client/net/listener_listen.go @@ -0,0 +1,153 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() || AdvancedRouting() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := hooks.GenerateConnID() + + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.UDPConn) +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *PacketConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen()) + + addressRemoveHooks := hooks.GetAddressRemoveHooks() + if len(addressRemoveHooks) == 0 { + return + } + + for _, hook := range addressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + +// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality +func WrapPacketConn(conn net.PacketConn) net.PacketConn { + if AdvancedRouting() { + // hooks not required for advanced routing + return conn + } + return &PacketConn{ + PacketConn: conn, + ID: hooks.GenerateConnID(), + seenAddrs: &sync.Map{}, + } +} + +func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error { + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded { + return nil + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr) + } + + prefix, err := util.GetPrefixFromIP(udpAddr.IP) + if err != nil { + return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err) + } + + log.Debugf("Listener resolved IP for %s: %s", addr, prefix) + + var merr *multierror.Error + for _, hook := range writeHooks { + if err := hook(id, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/listener_listen_ios.go b/client/net/listener_listen_ios.go similarity index 100% rename from util/net/listener_listen_ios.go rename to client/net/listener_listen_ios.go diff --git a/util/net/net.go b/client/net/net.go similarity index 81% rename from util/net/net.go rename to client/net/net.go index fdcf4ee6a..a97de9d59 100644 --- a/util/net/net.go +++ b/client/net/net.go @@ -5,8 +5,6 @@ import ( "math/big" "net" "net/netip" - - "github.com/google/uuid" ) const ( @@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool { return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper } -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -type AddHookFunc func(connID ConnectionID, IP net.IP) error -type RemoveHookFunc func(connID ConnectionID) error - -// GenerateConnID generates a unique identifier for each connection. -func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} - func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) { var endIP net.IP addr := network.Addr().AsSlice() diff --git a/util/net/net_linux.go b/client/net/net_linux.go similarity index 100% rename from util/net/net_linux.go rename to client/net/net_linux.go diff --git a/util/net/net_test.go b/client/net/net_test.go similarity index 100% rename from util/net/net_test.go rename to client/net/net_test.go diff --git a/client/net/net_windows.go b/client/net/net_windows.go new file mode 100644 index 000000000..649d83aaf --- /dev/null +++ b/client/net/net_windows.go @@ -0,0 +1,284 @@ +package net + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "syscall" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options + IpUnicastIf = 31 + Ipv6UnicastIf = 31 + + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options + Ipv6V6only = 27 +) + +// GetBestInterfaceFunc is set at runtime to avoid import cycle +var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error) + +// nativeToBigEndian converts a uint32 from native byte order to big-endian +func nativeToBigEndian(v uint32) uint32 { + return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24 +} + +// parseDestinationAddress parses the destination address from various formats +func parseDestinationAddress(network, address string) (netip.Addr, error) { + if address == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + if addrPort, err := netip.ParseAddrPort(address); err == nil { + return addrPort.Addr(), nil + } + + if dest, err := netip.ParseAddr(address); err == nil { + return dest, nil + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + // No port, treat whole string as host + host = address + } + + if host == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil || len(ips) == 0 { + return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err) + } + + dest, ok := netip.AddrFromSlice(ips[0].IP) + if !ok { + return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP) + } + + if ips[0].Zone != "" { + dest = dest.WithZone(ips[0].Zone) + } + + return dest, nil +} + +func getInterfaceFromZone(zone string) *net.Interface { + if zone == "" { + return nil + } + + idx, err := strconv.Atoi(zone) + if err != nil { + log.Debugf("invalid zone format for Windows (expected numeric): %s", zone) + return nil + } + + iface, err := net.InterfaceByIndex(idx) + if err != nil { + log.Debugf("failed to get interface by index %d from zone: %v", idx, err) + return nil + } + + return iface +} + +type interfaceSelection struct { + iface4 *net.Interface + iface6 *net.Interface +} + +func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection { + iface := getInterfaceFromZone(zone) + if iface == nil { + return nil + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface} + } + return &interfaceSelection{iface4: iface} +} + +func selectInterfaceForUnspecified() (*interfaceSelection, error) { + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + var result interfaceSelection + vpnIfaceName := GetVPNInterfaceName() + + if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil { + result.iface4 = iface4 + } else { + log.Debugf("No IPv4 default route found: %v", err) + } + + if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil { + result.iface6 = iface6 + } else { + log.Debugf("No IPv6 default route found: %v", err) + } + + if result.iface4 == nil && result.iface6 == nil { + return nil, errors.New("no default routes found") + } + + return &result, nil +} + +func selectInterface(dest netip.Addr) (*interfaceSelection, error) { + if zone := dest.Zone(); zone != "" { + if selection := selectInterfaceForZone(dest, zone); selection != nil { + return selection, nil + } + } + + if dest.IsUnspecified() { + return selectInterfaceForUnspecified() + } + + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName()) + if err != nil { + return nil, fmt.Errorf("find route for %s: %w", dest, err) + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface}, nil + } + return &interfaceSelection{iface4: iface}, nil +} + +func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error { + ifaceIndexBE := nativeToBigEndian(uint32(iface.Index)) + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil { + return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error { + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil { + return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error { + // The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.) + // Never generic ones (udp, tcp, ip) + + switch { + case strings.HasSuffix(network, "4"): + // IPv4-only socket (udp4, tcp4, ip4) + return setUnicastIfIPv4(fd, network, selection, address) + + case strings.HasSuffix(network, "6"): + // IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only + return setUnicastIfIPv6(fd, network, selection, address) + } + + // Shouldn't reach here based on Go's documented behavior + return fmt.Errorf("unexpected network type: %s", network) +} + +func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error { + if selection.iface4 == nil { + return nil + } + + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address) + return nil +} + +func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error { + isDualStack := checkDualStack(fd) + + // For dual-stack sockets, also set the IPv4 option + if isDualStack && selection.iface4 != nil { + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address) + } + + if selection.iface6 == nil { + return nil + } + + if err := setIPv6UnicastIF(fd, selection.iface6); err != nil { + return err + } + + log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address) + return nil +} + +func checkDualStack(fd uintptr) bool { + var v6Only int + v6OnlyLen := int32(unsafe.Sizeof(v6Only)) + err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen) + return err == nil && v6Only == 0 +} + +// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address +func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + dest, err := parseDestinationAddress(network, address) + if err != nil { + return err + } + + dest = dest.Unmap() + + if !dest.IsValid() { + return fmt.Errorf("invalid destination address for %s", address) + } + + selection, err := selectInterface(dest) + if err != nil { + return err + } + + var controlErr error + err = c.Control(func(fd uintptr) { + controlErr = setUnicastIf(fd, network, selection, address) + }) + + if err != nil { + return fmt.Errorf("control: %w", err) + } + + return controlErr +} diff --git a/util/net/protectsocket_android.go b/client/net/protectsocket_android.go similarity index 100% rename from util/net/protectsocket_android.go rename to client/net/protectsocket_android.go diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 2422d2683..7c9fa021a 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -2,7 +2,7 @@ set -eEuo pipefail : ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"} +: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() @@ -39,7 +39,7 @@ wait_for_message() { info "not waiting for log line ${message@Q} due to zero timeout." elif test -n "${log_file_path}"; then info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) + grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) else info "log file unsupported, sleeping for ${timeout} seconds..." sleep "${timeout}" @@ -81,7 +81,7 @@ wait_for_daemon_startup() { login_if_needed() { local timeout="${1}" - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then + if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then info "already logged in, skipping 'netbird up'..." else info "logging in..." diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index c633afc83..841e3c0f7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v5.29.3 +// protoc v6.32.1 // source: daemon.proto package proto @@ -794,8 +794,10 @@ type StatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"` ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *StatusRequest) Reset() { @@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool { return false } +func (x *StatusRequest) GetWaitForReady() bool { + if x != nil && x.WaitForReady != nil { + return *x.WaitForReady + } + return false +} + type StatusResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // status of the server. @@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\f\n" + "\n" + - "UpResponse\"g\n" + + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + "\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" + - "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" + + "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" + + "\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" + + "\r_waitForReady\"\x82\x01\n" + "\x0eStatusResponse\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status\x122\n" + "\n" + @@ -5231,6 +5242,7 @@ func file_daemon_proto_init() { } file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[7].OneofWrappers = []any{} file_daemon_proto_msgTypes[26].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 0cd3579b9..5b27b4d98 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -186,6 +186,8 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; bool shouldRunProbes = 2; + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + optional bool waitForReady = 3; } message StatusResponse{ diff --git a/client/server/server.go b/client/server/server.go index d89c7ce91..e6de608c5 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -65,6 +65,9 @@ type Server struct { mutex sync.Mutex config *profilemanager.Config proto.UnimplementedDaemonServiceServer + clientRunning bool // protected by mutex + clientRunningChan chan struct{} + clientGiveUpChan chan struct{} connectClient *internal.ConnectClient @@ -103,6 +106,11 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() + + if s.clientRunning { + return nil + } + state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -172,8 +180,10 @@ func (s *Server) Start() error { return nil } - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) - + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } @@ -204,12 +214,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - runningChan chan struct{}, -) { - backOff := getConnectWithBackoff(ctx) - retryStarted := false +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { + defer func() { + s.mutex.Lock() + s.clientRunning = false + s.mutex.Unlock() + }() + if s.config.DisableAutoConnect { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + log.Debugf("run client connection exited with error: %v", err) + } + log.Tracef("client connection exited") + return + } + + backOff := getConnectWithBackoff(ctx) go func() { t := time.NewTicker(24 * time.Hour) for { @@ -218,89 +238,36 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage t.Stop() return case <-t.C: - if retryStarted { - - mgmtState := statusRecorder.GetManagementState() - signalState := statusRecorder.GetSignalState() - if mgmtState.Connected && signalState.Connected { - log.Tracef("resetting status") - retryStarted = false - } else { - log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) - } + mgmtState := statusRecorder.GetManagementState() + signalState := statusRecorder.GetSignalState() + if mgmtState.Connected && signalState.Connected { + log.Tracef("resetting status") + backOff.Reset() + } else { + log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) } } } }() runOperation := func() error { - log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - - err := s.connectClient.Run(runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) + return err } - if config.DisableAutoConnect { - return backoff.Permanent(err) - } - - if !retryStarted { - retryStarted = true - backOff.Reset() - } - - log.Tracef("client connection exited") - return fmt.Errorf("client connection exited") + log.Tracef("client connection exited gracefully, do not need to retry") + return nil } - err := backoff.Retry(runOperation, backOff) - if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { - log.Errorf("received an error when trying to connect: %v", err) - } else { - log.Tracef("retry canceled") - } -} - -// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries -func getConnectWithBackoff(ctx context.Context) backoff.BackOff { - initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) - maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) - maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) - multiplier := defaultRetryMultiplier - - if envValue := os.Getenv(retryMultiplierVar); envValue != "" { - // parse the multiplier from the environment variable string value to float64 - value, err := strconv.ParseFloat(envValue, 64) - if err != nil { - log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) - } else { - multiplier = value - } + if err := backoff.Retry(runOperation, backOff); err != nil { + log.Errorf("operation failed: %v", err) } - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: initialInterval, - RandomizationFactor: 1, - Multiplier: multiplier, - MaxInterval: maxInterval, - MaxElapsedTime: maxElapsedTime, // 14 days - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - -// parseEnvDuration parses the environment variable and returns the duration -func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { - if envValue := os.Getenv(envVar); envValue != "" { - if duration, err := time.ParseDuration(envValue); err == nil { - return duration - } - log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + if giveUpChan != nil { + close(giveUpChan) } - return defaultDuration } // loginAttempt attempts to login using the provided information. it returns a status in case something fails @@ -419,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if s.actCancel != nil { s.actCancel() } - ctx, cancel := context.WithCancel(s.rootCtx) + ctx, cancel := context.WithCancel(callerCtx) md, ok := metadata.FromIncomingContext(callerCtx) if ok { @@ -429,11 +396,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() - if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { + if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } - state := internal.CtxGetState(ctx) + state := internal.CtxGetState(s.rootCtx) defer func() { status, err := state.Status() if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) { @@ -646,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin // Up starts engine work in the daemon. func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { s.mutex.Lock() + if s.clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return nil, err + } + if status == internal.StatusNeedsLogin { + s.actCancel() + } + s.mutex.Unlock() + + return s.waitForUp(callerCtx) + } defer s.mutex.Unlock() if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { @@ -661,16 +642,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR if err != nil { return nil, err } + if status != internal.StatusIdle { return nil, fmt.Errorf("up already in progress: current status %s", status) } - // it should be nil here, but . + // it should be nil here, but in case it isn't we cancel it. if s.actCancel != nil { s.actCancel() } ctx, cancel := context.WithCancel(s.rootCtx) - md, ok := metadata.FromIncomingContext(callerCtx) if ok { ctx = metadata.NewOutgoingContext(ctx, md) @@ -713,23 +694,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + return s.waitForUp(callerCtx) +} + +// todo: handle potential race conditions +func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) { timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) - for { - select { - case <-runningChan: - s.isSessionActive.Store(true) - return &proto.UpResponse{}, nil - case <-callerCtx.Done(): - log.Debug("context done, stopping the wait for engine to become ready") - return nil, callerCtx.Err() - case <-timeoutCtx.Done(): - log.Debug("up is timed out, stopping the wait for engine to become ready") - return nil, timeoutCtx.Err() - } + select { + case <-s.clientGiveUpChan: + return nil, fmt.Errorf("client gave up to connect") + case <-s.clientRunningChan: + s.isSessionActive.Store(true) + return &proto.UpResponse{}, nil + case <-callerCtx.Done(): + log.Debug("context done, stopping the wait for engine to become ready") + return nil, callerCtx.Err() + case <-timeoutCtx.Done(): + log.Debug("up is timed out, stopping the wait for engine to become ready") + return nil, timeoutCtx.Err() } } @@ -1003,12 +992,46 @@ func (s *Server) Status( ctx context.Context, msg *proto.StatusRequest, ) (*proto.StatusResponse, error) { - if ctx.Err() != nil { - return nil, ctx.Err() - } - s.mutex.Lock() - defer s.mutex.Unlock() + clientRunning := s.clientRunning + s.mutex.Unlock() + + if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + return nil, err + } + + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + loop: + for { + select { + case <-s.clientGiveUpChan: + ticker.Stop() + break loop + case <-s.clientRunningChan: + ticker.Stop() + break loop + case <-ticker.C: + status, err := state.Status() + if err != nil { + continue + } + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + continue + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } status, err := internal.CtxGetState(s.rootCtx).Status() if err != nil { @@ -1127,6 +1150,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p }, nil } +// AddProfile adds a new profile to the daemon. +func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.checkProfilesDisabled() { + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + + if msg.ProfileName == "" || msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") + } + + if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to create profile: %v", err) + return nil, fmt.Errorf("failed to create profile: %w", err) + } + + return &proto.AddProfileResponse{}, nil +} + +// RemoveProfile removes a profile from the daemon. +func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { + return nil, err + } + + if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { + log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) + } + + if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to remove profile: %v", err) + return nil, fmt.Errorf("failed to remove profile: %w", err) + } + + return &proto.RemoveProfileResponse{}, nil +} + +// ListProfiles lists all profiles in the daemon. +func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") + } + + profiles, err := s.profileManager.ListProfiles(msg.Username) + if err != nil { + log.Errorf("failed to list profiles: %v", err) + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + response := &proto.ListProfilesResponse{ + Profiles: make([]*proto.Profile, len(profiles)), + } + for i, profile := range profiles { + response.Profiles[i] = &proto.Profile{ + Name: profile.Name, + IsActive: profile.IsActive, + } + } + + return response, nil +} + +// GetActiveProfile returns the active profile in the daemon. +func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + activeProfile, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + return &proto.GetActiveProfileResponse{ + ProfileName: activeProfile.Name, + Username: activeProfile.Username, + }, nil +} + +// GetFeatures returns the features supported by the daemon. +func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + features := &proto.GetFeaturesResponse{ + DisableProfiles: s.checkProfilesDisabled(), + DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + } + + return features, nil +} + +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { + log.Tracef("running client connection") + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) + if err := s.connectClient.Run(runningChan); err != nil { + return err + } + return nil +} + +func (s *Server) checkProfilesDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.profilesDisabled { + return true + } + + return false +} + +func (s *Server) checkUpdateSettingsDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.updateSettingsDisabled { + return true + } + + return false +} + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() @@ -1138,6 +1289,45 @@ func (s *Server) onSessionExpire() { } } +// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries +func getConnectWithBackoff(ctx context.Context) backoff.BackOff { + initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) + maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) + maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) + multiplier := defaultRetryMultiplier + + if envValue := os.Getenv(retryMultiplierVar); envValue != "" { + // parse the multiplier from the environment variable string value to float64 + value, err := strconv.ParseFloat(envValue, 64) + if err != nil { + log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) + } else { + multiplier = value + } + } + + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: initialInterval, + RandomizationFactor: 1, + Multiplier: multiplier, + MaxInterval: maxInterval, + MaxElapsedTime: maxElapsedTime, // 14 days + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// parseEnvDuration parses the environment variable and returns the duration +func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { + if envValue := os.Getenv(envVar); envValue != "" { + if duration, err := time.ParseDuration(envValue); err == nil { + return duration + } + log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + } + return defaultDuration +} + func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus := proto.FullStatus{ ManagementState: &proto.ManagementState{}, @@ -1252,121 +1442,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// AddProfile adds a new profile to the daemon. -func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.checkProfilesDisabled() { - return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) - } - - if msg.ProfileName == "" || msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") - } - - if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to create profile: %v", err) - return nil, fmt.Errorf("failed to create profile: %w", err) - } - - return &proto.AddProfileResponse{}, nil -} - -// RemoveProfile removes a profile from the daemon. -func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { - return nil, err - } - - if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { - log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) - } - - if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to remove profile: %v", err) - return nil, fmt.Errorf("failed to remove profile: %w", err) - } - - return &proto.RemoveProfileResponse{}, nil -} - -// ListProfiles lists all profiles in the daemon. -func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") - } - - profiles, err := s.profileManager.ListProfiles(msg.Username) - if err != nil { - log.Errorf("failed to list profiles: %v", err) - return nil, fmt.Errorf("failed to list profiles: %w", err) - } - - response := &proto.ListProfilesResponse{ - Profiles: make([]*proto.Profile, len(profiles)), - } - for i, profile := range profiles { - response.Profiles[i] = &proto.Profile{ - Name: profile.Name, - IsActive: profile.IsActive, - } - } - - return response, nil -} - -// GetActiveProfile returns the active profile in the daemon. -func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - activeProfile, err := s.profileManager.GetActiveProfileState() - if err != nil { - log.Errorf("failed to get active profile state: %v", err) - return nil, fmt.Errorf("failed to get active profile state: %w", err) - } - - return &proto.GetActiveProfileResponse{ - ProfileName: activeProfile.Name, - Username: activeProfile.Username, - }, nil -} - -// GetFeatures returns the features supported by the daemon. -func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - features := &proto.GetFeaturesResponse{ - DisableProfiles: s.checkProfilesDisabled(), - DisableUpdateSettings: s.checkUpdateSettingsDisabled(), - } - - return features, nil -} - -func (s *Server) checkProfilesDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.profilesDisabled { - return true - } - - return false -} - -func (s *Server) checkUpdateSettingsDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.updateSettingsDisabled { - return true - } - - return false -} diff --git a/client/server/server_test.go b/client/server/server_test.go index 87889cbce..e0a4805f6 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,22 +10,26 @@ import ( "time" "github.com/golang/mock/gomock" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + + "github.com/netbirdio/management-integrations/integrations" + + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" - "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -104,7 +108,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -133,8 +137,12 @@ func TestServer_Up(t *testing.T) { profName := "default" + u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") + require.NoError(t, err) + ic := profilemanager.ConfigInput{ - ConfigPath: filepath.Join(tempDir, profName+".json"), + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: u.String(), } _, err = profilemanager.UpdateOrCreateConfig(ic) @@ -152,16 +160,9 @@ func TestServer_Up(t *testing.T) { } s := New(ctx, "console", "", false, false) - err = s.Start() require.NoError(t, err) - u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") - require.NoError(t, err) - s.config = &profilemanager.Config{ - ManagementURL: u, - } - upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -170,6 +171,7 @@ func TestServer_Up(t *testing.T) { Username: &currUser.Username, } _, err = s.Up(upCtx, upReq) + log.Errorf("error from Up: %v", err) assert.Contains(t, err.Error(), "context deadline exceeded") } @@ -315,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/ssh/client.go b/client/ssh/client.go index 2dc70e8fc..afba347f8 100644 --- a/client/ssh/client.go +++ b/client/ssh/client.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/login.go b/client/ssh/login.go index d1d56ceb0..cb2615e55 100644 --- a/client/ssh/login.go +++ b/client/ssh/login.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server.go b/client/ssh/server.go index 1f2001d0f..8c5db2547 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go index cc080ffdb..76f43fd4e 100644 --- a/client/ssh/server_mock.go +++ b/client/ssh/server_mock.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import "context" diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go index 5caca1834..1f310c2bb 100644 --- a/client/ssh/server_test.go +++ b/client/ssh/server_test.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/ssh_js.go b/client/ssh/ssh_js.go new file mode 100644 index 000000000..8cea88702 --- /dev/null +++ b/client/ssh/ssh_js.go @@ -0,0 +1,137 @@ +package ssh + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "strings" + + "golang.org/x/crypto/ssh" +) + +var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment") + +// Server is a dummy SSH server interface for WASM. +type Server interface { + Start() error + Stop() error + EnableSSH(enabled bool) + AddAuthorizedKey(peer string, key string) error + RemoveAuthorizedKey(key string) +} + +type dummyServer struct{} + +func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { + return &dummyServer{}, nil +} + +func NewServer(addr string) Server { + return &dummyServer{} +} + +func (s *dummyServer) Start() error { + return ErrSSHNotSupported +} + +func (s *dummyServer) Stop() error { + return nil +} + +func (s *dummyServer) EnableSSH(enabled bool) { +} + +func (s *dummyServer) AddAuthorizedKey(peer string, key string) error { + return nil +} + +func (s *dummyServer) RemoveAuthorizedKey(key string) { +} + +type Client struct{} + +func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) { + return nil, ErrSSHNotSupported +} + +func (c *Client) Close() error { + return nil +} + +func (c *Client) Run(command []string) error { + return ErrSSHNotSupported +} + +type SessionRecorder struct{} + +func NewSessionRecorder() *SessionRecorder { + return &SessionRecorder{} +} + +func (r *SessionRecorder) Record(session string, data []byte) { +} + +func GetUserShell() string { + return "/bin/sh" +} + +func LookupUserInfo(username string) (string, string, error) { + return "", "", ErrSSHNotSupported +} + +const DefaultSSHPort = 44338 + +const ED25519 = "ed25519" + +func isRoot() bool { + return false +} + +func GeneratePrivateKey(keyType string) ([]byte, error) { + if keyType != ED25519 { + return nil, errors.New("only ED25519 keys are supported in WASM") + } + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, err + } + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Bytes, + } + + pemBytes := pem.EncodeToMemory(pemBlock) + return pemBytes, nil +} + +func GeneratePublicKey(privateKey []byte) ([]byte, error) { + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + block, _ := pem.Decode(privateKey) + if block != nil { + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + signer, err = ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + } else { + return nil, err + } + } + + pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey()) + return []byte(strings.TrimSpace(string(pubKeyBytes))), nil +} diff --git a/client/ssh/util.go b/client/ssh/util.go index cf5f1396e..a54a609bc 100644 --- a/client/ssh/util.go +++ b/client/ssh/util.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/system/info.go b/client/system/info.go index ceb1682f3..a180be4c0 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -6,6 +6,7 @@ import ( "net/netip" "strings" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/metadata" "github.com/netbirdio/netbird/shared/management/proto" @@ -172,6 +173,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { // GetInfoWithChecks retrieves and parses the system information with applied checks. func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { + log.Debugf("gathering system information with checks: %d", len(checks)) processCheckPaths := make([]string, 0) for _, check := range checks { processCheckPaths = append(processCheckPaths, check.GetFiles()...) @@ -181,9 +183,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro if err != nil { return nil, err } + log.Debugf("gathering process check information completed") info := GetInfo(ctx) info.Files = files + log.Debugf("all system information gathered successfully") return info, nil } diff --git a/client/system/info_js.go b/client/system/info_js.go new file mode 100644 index 000000000..994d439a7 --- /dev/null +++ b/client/system/info_js.go @@ -0,0 +1,231 @@ +package system + +import ( + "context" + "runtime" + "strings" + "syscall/js" + + "github.com/netbirdio/netbird/version" +) + +// UpdateStaticInfoAsync is a no-op on JS as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + +// GetInfo retrieves system information for WASM environment +func GetInfo(_ context.Context) *Info { + info := &Info{ + GoOS: runtime.GOOS, + Kernel: runtime.GOARCH, + KernelVersion: runtime.GOARCH, + Platform: runtime.GOARCH, + OS: runtime.GOARCH, + Hostname: "wasm-client", + CPUs: runtime.NumCPU(), + NetbirdVersion: version.NetbirdVersion(), + } + + collectBrowserInfo(info) + collectLocationInfo(info) + collectSystemInfo(info) + return info +} + +func collectBrowserInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + collectUserAgent(info, navigator) + collectPlatform(info, navigator) + collectCPUInfo(info, navigator) +} + +func collectUserAgent(info *Info, navigator js.Value) { + ua := navigator.Get("userAgent") + if ua.IsUndefined() { + return + } + + userAgent := ua.String() + os, osVersion := parseOSFromUserAgent(userAgent) + if os != "" { + info.OS = os + } + if osVersion != "" { + info.OSVersion = osVersion + } +} + +func collectPlatform(info *Info, navigator js.Value) { + // Try regular platform property + if plat := navigator.Get("platform"); !plat.IsUndefined() { + if platStr := plat.String(); platStr != "" { + info.Platform = platStr + } + } + + // Try newer userAgentData API for more accurate platform + userAgentData := navigator.Get("userAgentData") + if userAgentData.IsUndefined() { + return + } + + platformInfo := userAgentData.Get("platform") + if !platformInfo.IsUndefined() { + if platStr := platformInfo.String(); platStr != "" { + info.Platform = platStr + } + } +} + +func collectCPUInfo(info *Info, navigator js.Value) { + hardwareConcurrency := navigator.Get("hardwareConcurrency") + if !hardwareConcurrency.IsUndefined() { + info.CPUs = hardwareConcurrency.Int() + } +} + +func collectLocationInfo(info *Info) { + location := js.Global().Get("location") + if location.IsUndefined() { + return + } + + if host := location.Get("hostname"); !host.IsUndefined() { + hostnameStr := host.String() + if hostnameStr != "" && hostnameStr != "localhost" { + info.Hostname = hostnameStr + } + } +} + +func checkFileAndProcess(_ []string) ([]File, error) { + return []File{}, nil +} + +func collectSystemInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + if vendor := navigator.Get("vendor"); !vendor.IsUndefined() { + info.SystemManufacturer = vendor.String() + } + + if product := navigator.Get("product"); !product.IsUndefined() { + info.SystemProductName = product.String() + } + + if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() { + ua := userAgent.String() + info.Environment = detectEnvironmentFromUA(ua) + } +} + +func parseOSFromUserAgent(userAgent string) (string, string) { + if userAgent == "" { + return "", "" + } + + switch { + case strings.Contains(userAgent, "Windows NT"): + return parseWindowsVersion(userAgent) + case strings.Contains(userAgent, "Mac OS X"): + return parseMacOSVersion(userAgent) + case strings.Contains(userAgent, "FreeBSD"): + return "FreeBSD", "" + case strings.Contains(userAgent, "OpenBSD"): + return "OpenBSD", "" + case strings.Contains(userAgent, "NetBSD"): + return "NetBSD", "" + case strings.Contains(userAgent, "Linux"): + return parseLinuxVersion(userAgent) + case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"): + return parseiOSVersion(userAgent) + case strings.Contains(userAgent, "CrOS"): + return "ChromeOS", "" + default: + return "", "" + } +} + +func parseWindowsVersion(userAgent string) (string, string) { + switch { + case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"): + return "Windows", "10/11" + case strings.Contains(userAgent, "Windows NT 10.0"): + return "Windows", "10" + case strings.Contains(userAgent, "Windows NT 6.3"): + return "Windows", "8.1" + case strings.Contains(userAgent, "Windows NT 6.2"): + return "Windows", "8" + case strings.Contains(userAgent, "Windows NT 6.1"): + return "Windows", "7" + default: + return "Windows", "Unknown" + } +} + +func parseMacOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "Mac OS X ") + if idx == -1 { + return "macOS", "Unknown" + } + + versionStart := idx + len("Mac OS X ") + versionEnd := strings.Index(userAgent[versionStart:], ")") + if versionEnd <= 0 { + return "macOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "macOS", ver +} + +func parseLinuxVersion(userAgent string) (string, string) { + if strings.Contains(userAgent, "Android") { + return "Android", extractAndroidVersion(userAgent) + } + if strings.Contains(userAgent, "Ubuntu") { + return "Ubuntu", "" + } + return "Linux", "" +} + +func parseiOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "OS ") + if idx == -1 { + return "iOS", "Unknown" + } + + versionStart := idx + 3 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd <= 0 { + return "iOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "iOS", ver +} + +func extractAndroidVersion(userAgent string) string { + if idx := strings.Index(userAgent, "Android "); idx != -1 { + versionStart := idx + len("Android ") + versionEnd := strings.IndexAny(userAgent[versionStart:], ";)") + if versionEnd > 0 { + return userAgent[versionStart : versionStart+versionEnd] + } + } + return "Unknown" +} + +func detectEnvironmentFromUA(_ string) Environment { + return Environment{} +} diff --git a/client/system/info_windows.go b/client/system/info_windows.go index e67356f57..d7f8f30aa 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -48,6 +48,5 @@ func GetInfo(ctx context.Context) *Info { gio.Hostname = extractDeviceName(ctx, systemHostname) gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) - return gio } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 09c76d8c0..66e150b7d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -595,27 +595,28 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) - if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) - return - } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("down service: %v", err) - } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + log.Errorf("get service status: %v", err) + dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) return } - } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + return + } + } + }() } }, OnCancel: func() { diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go new file mode 100644 index 000000000..d542e2739 --- /dev/null +++ b/client/wasm/cmd/main.go @@ -0,0 +1,245 @@ +//go:build js + +package main + +import ( + "context" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + + netbird "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/client/wasm/internal/http" + "github.com/netbirdio/netbird/client/wasm/internal/rdp" + "github.com/netbirdio/netbird/client/wasm/internal/ssh" + "github.com/netbirdio/netbird/util" +) + +const ( + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" +) + +func main() { + js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor)) + + select {} +} + +func startClient(ctx context.Context, nbClient *netbird.Client) error { + log.Info("Starting NetBird client...") + if err := nbClient.Start(ctx); err != nil { + return err + } + log.Info("NetBird client started successfully") + return nil +} + +// parseClientOptions extracts NetBird options from JavaScript object +func parseClientOptions(jsOptions js.Value) (netbird.Options, error) { + options := netbird.Options{ + DeviceName: "dashboard-client", + LogLevel: defaultLogLevel, + } + + if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() { + options.JWTToken = jwtToken.String() + } + + if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() { + options.SetupKey = setupKey.String() + } + + if privateKey := jsOptions.Get("privateKey"); !privateKey.IsNull() && !privateKey.IsUndefined() { + options.PrivateKey = privateKey.String() + } + + if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() { + mgmtURLStr := mgmtURL.String() + if mgmtURLStr != "" { + options.ManagementURL = mgmtURLStr + } + } + + if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() { + options.LogLevel = logLevel.String() + } + + if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() { + options.DeviceName = deviceName.String() + } + + return options, nil +} + +// createStartMethod creates the start method for the client +func createStartMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout) + defer cancel() + + if err := startClient(ctx, client); err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createStopMethod creates the stop method for the client +func createStopMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout) + defer cancel() + + if err := client.Stop(ctx); err != nil { + log.Errorf("Error stopping client: %v", err) + reject.Invoke(js.ValueOf(err.Error())) + return + } + + log.Info("NetBird client stopped") + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createSSHMethod creates the SSH connection method +func createSSHMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: requires host and port") + } + + host := args[0].String() + port := args[1].Int() + username := "root" + if len(args) > 2 && args[2].String() != "" { + username = args[2].String() + } + + return createPromise(func(resolve, reject js.Value) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username); err != nil { + reject.Invoke(err.Error()) + return + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + reject.Invoke(err.Error()) + return + } + + jsInterface := ssh.CreateJSInterface(sshClient) + resolve.Invoke(jsInterface) + }) + }) +} + +// createProxyRequestMethod creates the proxyRequest method +func createProxyRequestMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: request details required") + } + + request := args[0] + + return createPromise(func(resolve, reject js.Value) { + response, err := http.ProxyRequest(client, request) + if err != nil { + reject.Invoke(err.Error()) + return + } + resolve.Invoke(response) + }) + }) +} + +// createRDPProxyMethod creates the RDP proxy method +func createRDPProxyMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: hostname and port required") + } + + proxy := rdp.NewRDCleanPathProxy(client) + return proxy.CreateProxy(args[0].String(), args[1].String()) + }) +} + +// createPromise is a helper to create JavaScript promises +func createPromise(handler func(resolve, reject js.Value)) js.Value { + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + go handler(resolve, reject) + + return nil + })) +} + +// createClientObject wraps the NetBird client in a JavaScript object +func createClientObject(client *netbird.Client) js.Value { + obj := make(map[string]interface{}) + + obj["start"] = createStartMethod(client) + obj["stop"] = createStopMethod(client) + obj["createSSHConnection"] = createSSHMethod(client) + obj["proxyRequest"] = createProxyRequestMethod(client) + obj["createRDPProxy"] = createRDPProxyMethod(client) + + return js.ValueOf(obj) +} + +// netBirdClientConstructor acts as a JavaScript constructor function +func netBirdClientConstructor(this js.Value, args []js.Value) any { + return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + if len(args) < 1 { + reject.Invoke(js.ValueOf("Options object required")) + return nil + } + + go func() { + options, err := parseClientOptions(args[0]) + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil { + log.Warnf("Failed to initialize logging: %v", err) + } + + log.Infof("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s", + options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL) + + client, err := netbird.New(options) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err))) + return + } + + clientObj := createClientObject(client) + log.Info("NetBird client created successfully") + resolve.Invoke(clientObj) + }() + + return nil + })) +} diff --git a/client/wasm/internal/http/http.go b/client/wasm/internal/http/http.go new file mode 100644 index 000000000..cddc9e681 --- /dev/null +++ b/client/wasm/internal/http/http.go @@ -0,0 +1,100 @@ +//go:build js + +package http + +import ( + "fmt" + "io" + log "github.com/sirupsen/logrus" + "net/http" + "strings" + "syscall/js" + "time" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + httpTimeout = 30 * time.Second + maxResponseSize = 1024 * 1024 // 1MB +) + +// performRequest executes an HTTP request through NetBird and returns the response and body +func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) { + httpClient := nbClient.NewHTTPClient() + httpClient.Timeout = httpTimeout + + req, err := http.NewRequest(method, url, strings.NewReader(string(body))) + if err != nil { + return nil, nil, fmt.Errorf("create request: %w", err) + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("request failed: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Errorf("failed to close response body: %v", err) + } + }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) + if err != nil { + return nil, nil, fmt.Errorf("read response: %w", err) + } + + return resp, respBody, nil +} + +// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object +func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) { + url := request.Get("url").String() + if url == "" { + return js.Undefined(), fmt.Errorf("URL is required") + } + + method := "GET" + if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() { + method = strings.ToUpper(methodVal.String()) + } + + var requestBody []byte + if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() { + requestBody = []byte(bodyVal.String()) + } + + requestHeaders := make(map[string]string) + if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject { + headerKeys := js.Global().Get("Object").Call("keys", headersVal) + for i := 0; i < headerKeys.Length(); i++ { + key := headerKeys.Index(i).String() + value := headersVal.Get(key).String() + requestHeaders[key] = value + } + } + + resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody) + if err != nil { + return js.Undefined(), err + } + + result := js.Global().Get("Object").New() + result.Set("status", resp.StatusCode) + result.Set("statusText", resp.Status) + result.Set("body", string(body)) + + headers := js.Global().Get("Object").New() + for key, values := range resp.Header { + if len(values) > 0 { + headers.Set(strings.ToLower(key), values[0]) + } + } + result.Set("headers", headers) + + return result, nil +} diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go new file mode 100644 index 000000000..4a23a4bc8 --- /dev/null +++ b/client/wasm/internal/rdp/cert_validation.go @@ -0,0 +1,96 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + certValidationTimeout = 60 * time.Second +) + +func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) { + if !conn.wsHandlers.Get("onCertificateRequest").Truthy() { + return false, fmt.Errorf("certificate validation handler not configured") + } + + certInfo := js.Global().Get("Object").New() + certInfo.Set("ServerAddr", conn.destination) + + certArray := js.Global().Get("Array").New() + for i, certBytes := range certChain { + uint8Array := js.Global().Get("Uint8Array").New(len(certBytes)) + js.CopyBytesToJS(uint8Array, certBytes) + certArray.SetIndex(i, uint8Array) + } + certInfo.Set("ServerCertChain", certArray) + if len(certChain) > 0 { + cert, err := x509.ParseCertificate(certChain[0]) + if err == nil { + info := js.Global().Get("Object").New() + info.Set("subject", cert.Subject.String()) + info.Set("issuer", cert.Issuer.String()) + info.Set("validFrom", cert.NotBefore.Format(time.RFC3339)) + info.Set("validTo", cert.NotAfter.Format(time.RFC3339)) + info.Set("serialNumber", cert.SerialNumber.String()) + certInfo.Set("CertificateInfo", info) + } + } + + promise := conn.wsHandlers.Call("onCertificateRequest", certInfo) + + resultChan := make(chan bool) + errorChan := make(chan error) + + promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + result := args[0].Bool() + resultChan <- result + return nil + })).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + errorChan <- fmt.Errorf("certificate validation failed") + return nil + })) + + select { + case result := <-resultChan: + if result { + log.Info("Certificate accepted by user") + } else { + log.Info("Certificate rejected by user") + } + return result, nil + case err := <-errorChan: + return false, err + case <-time.After(certValidationTimeout): + return false, fmt.Errorf("certificate validation timeout") + } +} + +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, // We'll validate manually after handshake + VerifyConnection: func(cs tls.ConnectionState) error { + var certChain [][]byte + for _, cert := range cs.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + + accepted, err := p.validateCertificateWithJS(conn, certChain) + if err != nil { + return err + } + if !accepted { + return fmt.Errorf("certificate rejected by user") + } + + return nil + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go new file mode 100644 index 000000000..8062a05cc --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -0,0 +1,271 @@ +//go:build js + +package rdp + +import ( + "context" + "crypto/tls" + "encoding/asn1" + "fmt" + "io" + "net" + "sync" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +const ( + RDCleanPathVersion = 3390 + RDCleanPathProxyHost = "rdcleanpath.proxy.local" + RDCleanPathProxyScheme = "ws" +) + +type RDCleanPathPDU struct { + Version int64 `asn1:"tag:0,explicit"` + Error []byte `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathProxy struct { + nbClient interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) + } + activeConnections map[string]*proxyConnection + destinations map[string]string + mu sync.Mutex +} + +type proxyConnection struct { + id string + destination string + rdpConn net.Conn + tlsConn *tls.Conn + wsHandlers js.Value + ctx context.Context + cancel context.CancelFunc +} + +// NewRDCleanPathProxy creates a new RDCleanPath proxy +func NewRDCleanPathProxy(client interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) +}) *RDCleanPathProxy { + return &RDCleanPathProxy{ + nbClient: client, + activeConnections: make(map[string]*proxyConnection), + } +} + +// CreateProxy creates a new proxy endpoint for the given destination +func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { + destination := fmt.Sprintf("%s:%s", hostname, port) + + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any { + resolve := args[0] + + go func() { + proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections)) + + p.mu.Lock() + if p.destinations == nil { + p.destinations = make(map[string]string) + } + p.destinations[proxyID] = destination + p.mu.Unlock() + + proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID) + + // Register the WebSocket handler for this specific proxy + js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: requires WebSocket argument") + } + + ws := args[0] + p.HandleWebSocketConnection(ws, proxyID) + return nil + })) + + log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination) + resolve.Invoke(proxyURL) + }() + + return nil + })) +} + +// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP +func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) { + p.mu.Lock() + destination := p.destinations[proxyID] + p.mu.Unlock() + + if destination == "" { + log.Errorf("No destination found for proxy ID: %s", proxyID) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + // Don't defer cancel here - it will be called by cleanupConnection + + conn := &proxyConnection{ + id: proxyID, + destination: destination, + wsHandlers: ws, + ctx: ctx, + cancel: cancel, + } + + p.mu.Lock() + p.activeConnections[proxyID] = conn + p.mu.Unlock() + + p.setupWebSocketHandlers(ws, conn) + + log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID) +} + +func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) { + ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return nil + } + + data := args[0] + go p.handleWebSocketMessage(conn, data) + return nil + })) + + ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any { + log.Debug("WebSocket closed by JavaScript") + conn.cancel() + return nil + })) +} + +func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) { + if !data.InstanceOf(js.Global().Get("Uint8Array")) { + return + } + + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + + if conn.rdpConn != nil || conn.tlsConn != nil { + p.forwardToRDP(conn, bytes) + return + } + + var pdu RDCleanPathPDU + _, err := asn1.Unmarshal(bytes, &pdu) + if err != nil { + log.Warnf("Failed to parse RDCleanPath PDU: %v", err) + n := len(bytes) + if n > 20 { + n = 20 + } + log.Warnf("First %d bytes: %x", n, bytes[:n]) + + if len(bytes) > 0 && bytes[0] == 0x03 { + log.Debug("Received raw RDP packet instead of RDCleanPath PDU") + go p.handleDirectRDP(conn, bytes) + return + } + return + } + + go p.processRDCleanPathPDU(conn, pdu) +} + +func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) { + var writer io.Writer + var connType string + + if conn.tlsConn != nil { + writer = conn.tlsConn + connType = "TLS" + } else if conn.rdpConn != nil { + writer = conn.rdpConn + connType = "TCP" + } else { + log.Error("No RDP connection available") + return + } + + if _, err := writer.Write(bytes); err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + } +} + +func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) { + defer p.cleanupConnection(conn) + + destination := conn.destination + log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + return + } + conn.rdpConn = rdpConn + + _, err = rdpConn.Write(firstPacket) + if err != nil { + log.Errorf("Failed to write first packet: %v", err) + return + } + + response := make([]byte, 1024) + n, err := rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + return + } + + p.sendToWebSocket(conn, response[:n]) + + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") +} + +func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) { + log.Debugf("Cleaning up connection %s", conn.id) + conn.cancel() + if conn.tlsConn != nil { + log.Debug("Closing TLS connection") + if err := conn.tlsConn.Close(); err != nil { + log.Debugf("Error closing TLS connection: %v", err) + } + conn.tlsConn = nil + } + if conn.rdpConn != nil { + log.Debug("Closing TCP connection") + if err := conn.rdpConn.Close(); err != nil { + log.Debugf("Error closing TCP connection: %v", err) + } + conn.rdpConn = nil + } + p.mu.Lock() + delete(p.activeConnections, conn.id) + p.mu.Unlock() +} + +func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { + if conn.wsHandlers.Get("receiveFromGo").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer")) + } else if conn.wsHandlers.Get("send").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("send", uint8Array.Get("buffer")) + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go new file mode 100644 index 000000000..010efa5ea --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -0,0 +1,251 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "encoding/asn1" + "io" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) + + if pdu.Version != RDCleanPathVersion { + p.sendRDCleanPathError(conn, "Unsupported version") + return + } + + destination := conn.destination + if pdu.Destination != "" { + destination = pdu.Destination + } + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, "Connection failed") + p.cleanupConnection(conn) + return + } + conn.rdpConn = rdpConn + + // RDP always starts with X.224 negotiation, then determines if TLS is needed + // Modern RDP (since Windows Vista/2008) typically requires TLS + // The X.224 Connection Confirm response will indicate if TLS is required + // For now, we'll attempt TLS for all connections as it's the modern default + p.setupTLSConnection(conn, pdu) +} + +func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + var x224Response []byte + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + x224Response = response[:n] + log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) + } + + tlsConfig := p.getTLSConfigWithValidation(conn) + + tlsConn := tls.Client(conn.rdpConn, tlsConfig) + conn.tlsConn = tlsConn + + if err := tlsConn.Handshake(); err != nil { + log.Errorf("TLS handshake failed: %v", err) + p.sendRDCleanPathError(conn, "TLS handshake failed") + return + } + + log.Info("TLS handshake successful") + + // Certificate validation happens during handshake via VerifyConnection callback + var certChain [][]byte + connState := tlsConn.ConnectionState() + if len(connState.PeerCertificates) > 0 { + for _, cert := range connState.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + log.Debugf("Extracted %d certificates from TLS connection", len(certChain)) + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + ServerCertChain: certChain, + } + + if len(x224Response) > 0 { + responsePDU.X224ConnectionPDU = x224Response + } + + p.sendRDCleanPathPDU(conn, responsePDU) + + log.Debug("Starting TLS forwarding") + go p.forwardConnToWS(conn, conn.tlsConn, "TLS") + go p.forwardWSToConn(conn, conn.tlsConn, "TLS") + + <-conn.ctx.Done() + log.Debug("TLS connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + X224ConnectionPDU: response[:n], + ServerAddr: conn.destination, + } + + p.sendRDCleanPathPDU(conn, responsePDU) + } else { + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + } + p.sendRDCleanPathPDU(conn, responsePDU) + } + + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + + <-conn.ctx.Done() + log.Debug("TCP connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal RDCleanPath PDU: %v", err) + return + } + + log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data)) + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { + pdu := RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: []byte(errorMsg), + } + + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { + msgChan := make(chan []byte) + errChan := make(chan error) + + handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + if len(args) < 1 { + errChan <- io.EOF + return nil + } + + data := args[0] + if data.InstanceOf(js.Global().Get("Uint8Array")) { + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + msgChan <- bytes + } + return nil + }) + defer handler.Release() + + conn.wsHandlers.Set("onceGoMessage", handler) + + select { + case msg := <-msgChan: + return msg, nil + case err := <-errChan: + return nil, err + case <-conn.ctx.Done(): + return nil, conn.ctx.Err() + } +} + +func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) { + for { + if conn.ctx.Err() != nil { + return + } + + msg, err := p.readWebSocketMessage(conn) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from WebSocket: %v", err) + } + return + } + + _, err = dst.Write(msg) + if err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + return + } + } +} + +func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) { + buffer := make([]byte, 32*1024) + + for { + if conn.ctx.Err() != nil { + return + } + + n, err := src.Read(buffer) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from %s: %v", connType, err) + } + return + } + + if n > 0 { + p.sendToWebSocket(conn, buffer[:n]) + } + } +} diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go new file mode 100644 index 000000000..ca35525eb --- /dev/null +++ b/client/wasm/internal/ssh/client.go @@ -0,0 +1,213 @@ +//go:build js + +package ssh + +import ( + "context" + "fmt" + "io" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + sshDialTimeout = 30 * time.Second +) + +func closeWithLog(c io.Closer, resource string) { + if c != nil { + if err := c.Close(); err != nil { + logrus.Debugf("Failed to close %s: %v", resource, err) + } + } +} + +type Client struct { + nbClient *netbird.Client + sshClient *ssh.Client + session *ssh.Session + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader + mu sync.RWMutex +} + +// NewClient creates a new SSH client +func NewClient(nbClient *netbird.Client) *Client { + return &Client{ + nbClient: nbClient, + } +} + +// Connect establishes an SSH connection through NetBird network +func (c *Client) Connect(host string, port int, username string) error { + addr := fmt.Sprintf("%s:%d", host, port) + logrus.Infof("SSH: Connecting to %s as %s", addr, username) + + var authMethods []ssh.AuthMethod + + nbConfig, err := c.nbClient.GetConfig() + if err != nil { + return fmt.Errorf("get NetBird config: %w", err) + } + if nbConfig.SSHKey == "" { + return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization") + } + + signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey)) + if err != nil { + return fmt.Errorf("parse NetBird SSH private key: %w", err) + } + + pubKey := signer.PublicKey() + logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type()) + + authMethods = append(authMethods, ssh.PublicKeys(signer)) + + config := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: sshDialTimeout, + } + + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) + defer cancel() + + conn, err := c.nbClient.Dial(ctx, "tcp", addr) + if err != nil { + return fmt.Errorf("dial %s: %w", addr, err) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + closeWithLog(conn, "connection after handshake error") + return fmt.Errorf("SSH handshake: %w", err) + } + + c.sshClient = ssh.NewClient(sshConn, chans, reqs) + logrus.Infof("SSH: Connected to %s", addr) + + return nil +} + +// StartSession starts an SSH session with PTY +func (c *Client) StartSession(cols, rows int) error { + if c.sshClient == nil { + return fmt.Errorf("SSH client not connected") + } + + session, err := c.sshClient.NewSession() + if err != nil { + return fmt.Errorf("create session: %w", err) + } + + c.mu.Lock() + defer c.mu.Unlock() + c.session = session + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + ssh.VINTR: 3, + ssh.VQUIT: 28, + ssh.VERASE: 127, + } + + if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { + closeWithLog(session, "session after PTY error") + return fmt.Errorf("PTY request: %w", err) + } + + c.stdin, err = session.StdinPipe() + if err != nil { + closeWithLog(session, "session after stdin error") + return fmt.Errorf("get stdin: %w", err) + } + + c.stdout, err = session.StdoutPipe() + if err != nil { + closeWithLog(session, "session after stdout error") + return fmt.Errorf("get stdout: %w", err) + } + + c.stderr, err = session.StderrPipe() + if err != nil { + closeWithLog(session, "session after stderr error") + return fmt.Errorf("get stderr: %w", err) + } + + if err := session.Shell(); err != nil { + closeWithLog(session, "session after shell error") + return fmt.Errorf("start shell: %w", err) + } + + logrus.Info("SSH: Session started with PTY") + return nil +} + +// Write sends data to the SSH session +func (c *Client) Write(data []byte) (int, error) { + c.mu.RLock() + stdin := c.stdin + c.mu.RUnlock() + + if stdin == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdin.Write(data) +} + +// Read reads data from the SSH session +func (c *Client) Read(buffer []byte) (int, error) { + c.mu.RLock() + stdout := c.stdout + c.mu.RUnlock() + + if stdout == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdout.Read(buffer) +} + +// Resize updates the terminal size +func (c *Client) Resize(cols, rows int) error { + c.mu.RLock() + session := c.session + c.mu.RUnlock() + + if session == nil { + return fmt.Errorf("SSH session not started") + } + return session.WindowChange(rows, cols) +} + +// Close closes the SSH connection +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.session != nil { + closeWithLog(c.session, "SSH session") + c.session = nil + } + if c.stdin != nil { + closeWithLog(c.stdin, "stdin") + c.stdin = nil + } + c.stdout = nil + c.stderr = nil + + if c.sshClient != nil { + err := c.sshClient.Close() + c.sshClient = nil + return err + } + return nil +} diff --git a/client/wasm/internal/ssh/handlers.go b/client/wasm/internal/ssh/handlers.go new file mode 100644 index 000000000..ea64eb0aa --- /dev/null +++ b/client/wasm/internal/ssh/handlers.go @@ -0,0 +1,78 @@ +//go:build js + +package ssh + +import ( + "io" + "syscall/js" + + "github.com/sirupsen/logrus" +) + +// CreateJSInterface creates a JavaScript interface for the SSH client +func CreateJSInterface(client *Client) js.Value { + jsInterface := js.Global().Get("Object").Call("create", js.Null()) + + jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf(false) + } + + data := args[0] + var bytes []byte + + if data.Type() == js.TypeString { + bytes = []byte(data.String()) + } else { + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes = make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + } + + _, err := client.Write(bytes) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf(false) + } + cols := args[0].Int() + rows := args[1].Int() + err := client.Resize(cols, rows) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any { + client.Close() + return js.Undefined() + })) + + go readLoop(client, jsInterface) + + return jsInterface +} + +func readLoop(client *Client, jsInterface js.Value) { + buffer := make([]byte, 4096) + for { + n, err := client.Read(buffer) + if err != nil { + if err != io.EOF { + logrus.Debugf("SSH read error: %v", err) + } + if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() { + onclose.Invoke() + } + client.Close() + return + } + + if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() { + uint8Array := js.Global().Get("Uint8Array").New(n) + js.CopyBytesToJS(uint8Array, buffer[:n]) + ondata.Invoke(uint8Array) + } + } +} diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go new file mode 100644 index 000000000..4868ba30a --- /dev/null +++ b/client/wasm/internal/ssh/key.go @@ -0,0 +1,50 @@ +//go:build js + +package ssh + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "strings" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format +func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) { + keyStr := string(keyPEM) + if !strings.Contains(keyStr, "-----BEGIN") { + keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----") + } + + signer, err := ssh.ParsePrivateKey(keyPEM) + if err == nil { + return signer, nil + } + logrus.Debugf("SSH: Failed to parse as SSH format: %v", err) + + block, _ := pem.Decode(keyPEM) + if block == nil { + keyPreview := string(keyPEM) + if len(keyPreview) > 100 { + keyPreview = keyPreview[:100] + } + return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview) + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err) + if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(rsaKey) + } + if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(ecKey) + } + return nil, fmt.Errorf("parse private key: %w", err) + } + + return ssh.NewSignerFromKey(key) +} diff --git a/encryption/route53.go b/encryption/route53.go index 3c81ab103..48c7a3a1b 100644 --- a/encryption/route53.go +++ b/encryption/route53.go @@ -1,3 +1,5 @@ +//go:build !js + package encryption import ( diff --git a/flow/client/client.go b/flow/client/client.go index 949824065..318fcfe1e 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -20,9 +20,10 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" - nbgrpc "github.com/netbirdio/netbird/util/grpc" + "github.com/netbirdio/netbird/util/wsproxy" ) type GRPCClient struct { @@ -38,7 +39,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl return nil, fmt.Errorf("parsing url: %w", err) } var opts []grpc.DialOption - if parsedURL.Scheme == "https" { + tlsEnabled := parsedURL.Scheme == "https" + if tlsEnabled { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) @@ -53,7 +55,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl } opts = append(opts, - nbgrpc.WithCustomDialer(), + nbgrpc.WithCustomDialer(tlsEnabled, wsproxy.FlowComponent), grpc.WithIdleTimeout(interval*2), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/go.mod b/go.mod index 70e52875f..a1560b409 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 - github.com/coder/websocket v1.8.12 + github.com/coder/websocket v1.8.13 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 github.com/eko/gocache/lib/v4 v4.2.0 @@ -102,6 +102,7 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a + golang.org/x/mod v0.25.0 golang.org/x/net v0.42.0 golang.org/x/oauth2 v0.28.0 golang.org/x/sync v0.16.0 @@ -243,7 +244,6 @@ require ( go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect - golang.org/x/mod v0.25.0 // indirect golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.34.0 // indirect @@ -261,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 -replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 +replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index 3fdef5d08..13838b82d 100644 --- a/go.sum +++ b/go.sum @@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= @@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= -github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= -github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 08749a4f7..fb01e6867 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -45,6 +45,9 @@ services: - $SIGNAL_VOLUMENAME:/var/lib/netbird labels: - traefik.enable=true + - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) + - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal + - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c @@ -87,7 +90,9 @@ services: - traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`) - traefik.http.routers.netbird-api.service=netbird-api - traefik.http.services.netbird-api.loadbalancer.server.port=33073 - + - traefik.http.routers.netbird-wsproxy-mgmt.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/management`) + - traefik.http.routers.netbird-wsproxy-mgmt.service=netbird-wsproxy-mgmt + - traefik.http.services.netbird-wsproxy-mgmt.loadbalancer.server.port=33073 - traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`) - traefik.http.routers.netbird-management.service=netbird-management - traefik.http.services.netbird-management.loadbalancer.server.port=33073 diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2d7c65cbe..be9662345 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -328,6 +328,45 @@ delete_auto_service_user() { echo "$PARSED_RESPONSE" } +delete_default_zitadel_admin() { + INSTANCE_URL=$1 + PAT=$2 + + # Search for the default zitadel-admin user + RESPONSE=$( + curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + -d '{ + "queries": [ + { + "userNameQuery": { + "userName": "zitadel-admin@", + "method": "TEXT_QUERY_METHOD_STARTS_WITH" + } + } + ] + }' + ) + + DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty') + + if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then + echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID" + + RESPONSE=$( + curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + ) + PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"') + handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE" + + else + echo "Default zitadel-admin user not found: $RESPONSE" + fi +} + init_zitadel() { echo -e "\nInitializing Zitadel with NetBird's applications\n" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" @@ -346,6 +385,9 @@ init_zitadel() { echo -n "Waiting for Zitadel to become ready " wait_api "$INSTANCE_URL" "$PAT" + echo "Deleting default zitadel-admin user..." + delete_default_zitadel_admin "$INSTANCE_URL" "$PAT" + # create the zitadel project echo "Creating new zitadel project" PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") @@ -579,9 +621,11 @@ renderCaddyfile() { # relay reverse_proxy /relay* relay:80 # Signal + reverse_proxy /ws-proxy/signal* signal:10000 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management reverse_proxy /api/* management:80 + reverse_proxy /ws-proxy/management* management:80 reverse_proxy /management.ManagementService/* h2c://management:80 # Zitadel reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080 diff --git a/infrastructure_files/nginx.tmpl.conf b/infrastructure_files/nginx.tmpl.conf index f7fa4a9d0..46cb195e7 100644 --- a/infrastructure_files/nginx.tmpl.conf +++ b/infrastructure_files/nginx.tmpl.conf @@ -20,6 +20,10 @@ upstream management { # insert the grpc+http port of your management container here server 127.0.0.1:8012; } +upstream relay { + # insert the port of your relay container here + server 127.0.0.1:33080; +} server { # HTTP server config @@ -52,6 +56,14 @@ server { location / { proxy_pass http://dashboard; } + # Proxy Signal wsproxy endpoint + location /ws-proxy/signal { + proxy_pass http://signal; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; + } # Proxy Signal location /signalexchange.SignalExchange/ { grpc_pass grpc://signal; @@ -64,6 +76,14 @@ server { location /api { proxy_pass http://management; } + # Proxy Management wsproxy endpoint + location /ws-proxy/management { + proxy_pass http://management; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; + } # Proxy Management grpc endpoint location /management.ManagementService/ { grpc_pass grpc://management; @@ -72,6 +92,14 @@ server { grpc_send_timeout 1d; grpc_socket_keepalive on; } + # Proxy Relay + location /relay { + proxy_pass http://relay; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; + } ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem; ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem; diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 984a56a39..ddd81daa2 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -10,6 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { @@ -56,8 +58,8 @@ func (s *BaseServer) AuthManager() auth.Manager { }) } -func (s *BaseServer) EphemeralManager() *server.EphemeralManager { - return Create(s, func() *server.EphemeralManager { - return server.NewEphemeralManager(s.Store(), s.AccountManager()) +func (s *BaseServer) EphemeralManager() ephemeral.Manager { + return Create(s, func() ephemeral.Manager { + return manager.NewEphemeralManager(s.Store(), s.AccountManager()) }) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 70f0f93a9..daec4ef6f 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -65,6 +65,10 @@ func (s *BaseServer) AccountManager() account.Manager { if err != nil { log.Fatalf("failed to create account manager: %v", err) } + + s.AfterInit(func(s *BaseServer) { + accountManager.SetEphemeralManager(s.EphemeralManager()) + }) return accountManager }) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index cfe4f32e1..c761a98d4 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,12 +6,14 @@ import ( "fmt" "net" "net/http" + "net/netip" "strings" "sync" "time" "github.com/google/uuid" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -22,6 +24,8 @@ import ( "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" ) @@ -92,12 +96,6 @@ func (s *BaseServer) Start(ctx context.Context) error { s.PeersManager() s.GeoLocationManager() - for _, fn := range s.afterInit { - if fn != nil { - fn(s) - } - } - err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics") if err != nil { return fmt.Errorf("failed to expose metrics: %v", err) @@ -147,7 +145,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler()) + rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) switch { case s.certManager != nil: // a call to certManager.Listener() always creates a new listener so we do it once @@ -176,6 +174,12 @@ func (s *BaseServer) Start(ctx context.Context) error { } } + for _, fn := range s.afterInit { + if fn != nil { + fn(s) + } + } + log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) @@ -247,13 +251,17 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) return util.DirectWriteJson(ctx, path, config) } -func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler { +func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto") - if request.ProtoMajor == 2 && grpcHeader { + switch { + case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || + strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): gRPCHandler.ServeHTTP(writer, request) - } else { + case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent: + wsProxy.Handler().ServeHTTP(writer, request) + default: httpHandler.ServeHTTP(writer, request) } }) diff --git a/management/server/account.go b/management/server/account.go index ab9373f26..fcdab4b69 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -35,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -74,6 +75,7 @@ type DefaultAccountManager struct { ctx context.Context eventStore activity.Store geo geolocation.Geolocation + ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer @@ -261,6 +263,10 @@ func BuildManager( return am, nil } +func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { + am.ephemeralManager = em +} + func (am *DefaultAccountManager) startWarmup(ctx context.Context) { var initialInterval int64 intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 30fbbbc3e..a1ed9498b 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -12,6 +12,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -56,7 +57,7 @@ type Manager interface { UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) @@ -125,5 +126,6 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index 81a921bf9..07d2f2383 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -66,7 +66,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account setupKey = key.Key } - _, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false) if err != nil { t.Error("expected to add new peer successfully after creating new account, but failed", err) } @@ -1048,10 +1048,10 @@ func TestAccountManager_AddPeer(t *testing.T) { } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1112,10 +1112,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedUserID := userID - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) return @@ -1429,10 +1429,10 @@ func TestAccountManager_DeletePeer(t *testing.T) { peerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1805,11 +1805,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1861,11 +1861,11 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, @@ -1904,11 +1904,11 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2952,14 +2952,14 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Status: &nbpeer.PeerStatus{ Connected: true, LastSeen: time.Now().UTC(), }, - }) + }, false) if err != nil { t.Fatalf("expecting peer to be added, got failure %v", err) } @@ -3552,16 +3552,16 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { key2, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) require.NoError(t, err, "unable to add peer1") - peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) require.NoError(t, err, "unable to add peer2") t.Run("update peer IP successfully", func(t *testing.T) { diff --git a/management/server/dns.go b/management/server/dns.go index f6f0201d3..534f43ec6 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -6,9 +6,11 @@ import ( "sync" log "github.com/sirupsen/logrus" + "golang.org/x/mod/semver" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" @@ -18,31 +20,18 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +const ( + dnsForwarderPort = 22054 + oldForwarderPort = 5353 +) + +const dnsForwarderPortMinVersion = "v0.59.0" + // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { - CustomZones sync.Map NameServerGroups sync.Map } -// GetCustomZone retrieves a cached custom zone -func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) { - if c == nil { - return nil, false - } - if value, ok := c.CustomZones.Load(key); ok { - return value.(*proto.CustomZone), true - } - return nil, false -} - -// SetCustomZone stores a custom zone in the cache -func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) { - if c == nil { - return - } - c.CustomZones.Store(key, value) -} - // GetNameServerGroup retrieves a cached name server group func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { if c == nil { @@ -203,23 +192,50 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return validateGroups(settings.DisabledManagementGroups, groups) } +// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. +// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. +func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { + if len(peers) == 0 { + return oldForwarderPort + } + + reqVer := semver.Canonical(requiredVersion) + + // Check if all peers have the required version or newer + for _, peer := range peers { + + // Development version is always supported + if peer.Meta.WtVersion == "development" { + continue + } + peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) + if peerVersion == "" { + // If any peer doesn't have version info, return 0 + return oldForwarderPort + } + + // Compare versions + if semver.Compare(peerVersion, reqVer) < 0 { + return oldForwarderPort + } + } + + // All peers have the required version or newer + return dnsForwarderPort +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache -func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { +func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ ServiceEnable: update.ServiceEnable, CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, } for _, zone := range update.CustomZones { - cacheKey := zone.Domain - if cachedZone, exists := cache.GetCustomZone(cacheKey); exists { - protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone) - } else { - protoZone := convertToProtoCustomZone(zone) - cache.SetCustomZone(cacheKey, protoZone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) } for _, nsGroup := range update.NameServerGroups { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index d58689544..83caf74ef 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/shared/management/status" @@ -281,11 +280,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) + savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false) if err != nil { return nil, err } @@ -324,13 +323,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ + account.NameServerGroups[dnsNSGroup1] = &nbdns.NameServerGroup{ ID: dnsNSGroup1, Name: "ns-group-1", - NameServers: []dns.NameServer{{ + NameServers: []nbdns.NameServer{{ IP: netip.MustParseAddr(savedPeer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, Primary: true, Enabled: true, @@ -395,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache) + toProtocolDNSConfig(testData, cache, dnsForwarderPort) } }) @@ -403,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache) + toProtocolDNSConfig(testData, cache, dnsForwarderPort) } }) } @@ -456,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache) + result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache) + result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache) + result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -474,15 +473,6 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Results should be different for different inputs") } - // Verify that the cache contains elements from both configs - if _, exists := cache.GetCustomZone("example.com"); !exists { - t.Errorf("Cache should contain custom zone for example.com") - } - - if _, exists := cache.GetCustomZone("example.org"); !exists { - t.Errorf("Cache should contain custom zone for example.org") - } - if _, exists := cache.GetNameServerGroup("group1"); !exists { t.Errorf("Cache should contain name server group 'group1'") } @@ -492,6 +482,107 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } } +func TestComputeForwarderPort(t *testing.T) { + // Test with empty peers list + peers := []*nbpeer.Peer{} + result := computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) + } + + // Test with peers that have old versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.26.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) + } + + // Test with peers that have new versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != dnsForwarderPort { + t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) + } + + // Test with peers that have mixed versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) + } + + // Test with peers that have empty version + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) + } + + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result == oldForwarderPort { + t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) + } + + // Test with peers that have unknown version string + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "unknown", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) + } +} + func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) @@ -543,10 +634,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }() _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{ + context.Background(), account.Id, "ns-group", "ns-group", []nbdns.NameServer{{ IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, []string{"groupB"}, true, []string{}, true, userID, false, @@ -576,10 +667,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }() _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + context.Background(), account.Id, "ns-group-1", "ns-group-1", []nbdns.NameServer{{ IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, []string{"groupA"}, true, []string{}, true, userID, false, diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index f768414f0..58a8dcd8e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -22,6 +22,7 @@ import ( integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -55,7 +56,7 @@ type GRPCServer struct { config *nbconfig.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager + ephemeralManager ephemeral.Manager peerLocks sync.Map authManager auth.Manager @@ -73,7 +74,7 @@ func NewServer( peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager *EphemeralManager, + ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, ) (*GRPCServer, error) { @@ -258,6 +259,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -716,13 +718,13 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } } -func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), }, Checks: toProtocolChecks(ctx, checks), @@ -734,11 +736,11 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P response.NetworkMap.PeerConfig = response.PeerConfig - allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName) - response.RemotePeers = allPeers - response.NetworkMap.RemotePeers = allPeers - response.RemotePeersIsEmpty = len(allPeers) == 0 + remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + response.RemotePeers = remotePeers + response.NetworkMap.RemotePeers = remotePeers + response.RemotePeersIsEmpty = len(remotePeers) == 0 response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) @@ -810,7 +812,14 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups) + // Get all peers in the account for forwarder port computation + allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "") + if err != nil { + return fmt.Errorf("get account peers: %w", err) + } + dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion) + + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index af501e151..4b33495de 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -32,6 +32,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS") } // NewHandler creates a new peers Handler @@ -318,6 +319,88 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } +func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + + var req api.PeerTemporaryAccessRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + newPeer := &nbpeer.Peer{} + newPeer.FromAPITemporaryAccessRequest(&req) + + targetPeer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + for _, rule := range req.Rules { + protocol, portRange, err := types.ParseRuleString(rule) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + policy := &types.Policy{ + AccountID: userAuth.AccountId, + Description: "Temporary access policy for peer " + peer.Name, + Name: "Temporary access policy for peer " + peer.Name, + Enabled: true, + Rules: []*types.PolicyRule{{ + Name: "Temporary access rule", + Description: "Temporary access rule", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + SourceResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: peer.ID, + }, + DestinationResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: targetPeer.ID, + }, + Bidirectional: false, + Protocol: protocol, + PortRanges: []types.RulePortRange{portRange}, + }}, + } + + _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + } + + resp := &api.PeerTemporaryAccessResponse{ + Id: peer.ID, + Name: peer.Name, + Rules: req.Rules, + } + + util.WriteJSONObject(r.Context(), w, resp) +} + func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ba4997d22..a34d2086b 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -460,7 +461,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - ephemeralMgr := NewEphemeralManager(store, accountManager) + ephemeralMgr := manager.NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) if err != nil { return nil, nil, "", cleanup, err diff --git a/management/server/management_test.go b/management/server/management_test.go index 61dc46d87..1a5e47354 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -228,7 +229,7 @@ func startServer( peersUpdateManager, secretsManager, nil, - nil, + &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, ) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 003385eb5..d160e7269 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,6 +15,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -41,7 +42,7 @@ type MockAccountManager struct { DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) @@ -351,12 +352,14 @@ func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string // AddPeer mock implementation of AddPeer from server.AccountManager interface func (am *MockAccountManager) AddPeer( ctx context.Context, + accountID string, setupKey string, userId string, peer *nbpeer.Peer, + temporary bool, ) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { - return am.AddPeerFunc(ctx, setupKey, userId, peer) + return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } @@ -972,6 +975,11 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } +// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface +func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { + // Mock implementation - does nothing +} + func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 959e7856a..6c985410c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -876,11 +876,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false) if err != nil { return nil, err } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 294f51676..66484d120 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -132,7 +132,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc res := nbtypes.Resource{ ID: resource.ID, - Type: resource.Type.String(), + Type: nbtypes.ResourceType(resource.Type.String()), } for _, groupID := range resource.GroupIDs { event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) @@ -265,7 +265,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) { res := nbtypes.Resource{ ID: newResource.ID, - Type: newResource.Type.String(), + Type: nbtypes.ResourceType(newResource.Type.String()), } oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) diff --git a/management/server/peer.go b/management/server/peer.go index 2c55fc9ea..469b41991 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -450,7 +450,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -482,8 +482,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var ephemeral bool var groupsToAdd []string var allowExtraDNSLabels bool - var accountID string - var isEphemeral bool if addedByUser { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -492,10 +490,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if user.PendingApproval { return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") } - groupsToAdd = user.AutoGroups + if temporary { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create) + if err != nil { + return nil, nil, nil, status.NewPermissionValidationError(err) + } + + if !allowed { + return nil, nil, nil, status.NewPermissionDeniedError() + } + } else { + accountID = user.AccountID + groupsToAdd = user.AutoGroups + } opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser - accountID = user.AccountID } else { // Validate the setup key sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey) @@ -516,13 +525,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name allowExtraDNSLabels = sk.AllowExtraDNSLabels accountID = sk.AccountID - isEphemeral = sk.Ephemeral if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } } opEvent.AccountID = accountID + if temporary { + ephemeral = true + } + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { if am.idpManager != nil { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) @@ -549,10 +561,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s SSHKey: peer.SSHKey, LastLogin: ®istrationTime, CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, + LoginExpirationEnabled: addedByUser && !temporary, Ephemeral: ephemeral, Location: peer.Location, - InactivityExpirationEnabled: addedByUser, + InactivityExpirationEnabled: addedByUser && !temporary, ExtraDNSLabels: peer.ExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels, } @@ -588,7 +600,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var freeLabel string - if isEphemeral || attempt > 1 { + if ephemeral || attempt > 1 { freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) @@ -622,6 +634,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed adding peer to All group: %w", err) } + if temporary { + // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually + am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) + } + if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -712,7 +729,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool - var updated bool + var updated, versionChanged bool var err error var postureChecks []*posture.Checks @@ -752,7 +769,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return err } - updated = peer.UpdateMetaIfNew(sync.Meta) + updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta) if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) @@ -771,7 +788,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil, nil, nil, err } - if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { + if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -790,7 +807,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo ExtraDNSLabels: login.ExtraDNSLabels, } - return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) + return am.AddPeer(ctx, "", login.SetupKey, login.UserID, newPeer, false) } log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) @@ -863,7 +880,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return err } - isPeerUpdated = peer.UpdateMetaIfNew(login.Meta) + isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta) if isPeerUpdated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true @@ -877,6 +894,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer if peer.SSHKey != login.SSHKey { peer.SSHKey = login.SSHKey shouldStorePeer = true + updateRemotePeers = true } if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 { @@ -1211,6 +1229,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -1247,7 +1267,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account peerGroups := account.GetPeerGroups(p.ID) start = time.Now() - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups)) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) @@ -1358,7 +1378,9 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI } peerGroups := account.GetPeerGroups(peerId) - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups)) + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) + + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } @@ -1531,6 +1553,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } + dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion) + for _, peer := range peers { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID) @@ -1540,6 +1564,26 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peer.ID) + if err != nil { + return nil, err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } @@ -1555,6 +1599,9 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto FirewallRules: []*proto.FirewallRule{}, FirewallRulesIsEmpty: true, PeerConfig: toPeerConfig(peer, network, dnsDomain, settings), + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, }, }, NetworkMap: &types.NetworkMap{}, diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 6a6d1c91d..a898fd782 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -8,6 +8,7 @@ import ( "time" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/http/api" ) // Peer represents a machine connected to the network. @@ -232,21 +233,24 @@ func (p *Peer) Copy() *Peer { // UpdateMetaIfNew updates peer's system metadata if new information is provided // returns true if meta was updated, false otherwise -func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool { +func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) { if meta.isEmpty() { - return false + return updated, versionChanged } + versionChanged = p.Meta.WtVersion != meta.WtVersion + // Avoid overwriting UIVersion if the update was triggered sole by the CLI client if meta.UIVersion == "" { meta.UIVersion = p.Meta.UIVersion } if p.Meta.isEqual(meta) { - return false + return updated, versionChanged } p.Meta = meta - return true + updated = true + return updated, versionChanged } // GetLastLogin returns the last login time of the peer. @@ -334,6 +338,17 @@ func (p *Peer) UpdateLastLogin() *Peer { return p } +func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest) { + p.Ephemeral = true + p.Name = a.Name + p.Key = a.WgPubKey + p.Meta = PeerSystemMeta{ + Hostname: a.Name, + GoOS: "js", + OS: "js", + } +} + func (f Flags) isEqual(other Flags) bool { return f.RosenpassEnabled == other.RosenpassEnabled && f.RosenpassPermissive == other.RosenpassPermissive && diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 31c309430..42b3244ae 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -193,10 +193,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -207,10 +207,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -266,10 +266,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -280,10 +280,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) return } - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -442,10 +442,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -456,10 +456,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -514,10 +514,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -530,10 +530,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the second peer added with a setup key - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Fatal(err) return @@ -702,19 +702,19 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return } - _, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) assert.NotNil(t, response) // assert peer config @@ -1212,6 +1212,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID) // assert network map DNSConfig assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable) + assert.Equal(t, int64(dnsForwarderPort), response.NetworkMap.DNSConfig.ForwarderPort) assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones)) assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups)) // assert network map DNSConfig.CustomZones @@ -1300,7 +1301,7 @@ func Test_RegisterPeerByUser(t *testing.T) { }, } - addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false) require.NoError(t, err) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) @@ -1422,7 +1423,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels, } - addedPeer, _, _, err := am.AddPeer(context.Background(), tc.existingSetupKeyID, "", currentPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false) if tc.expectAddPeerError { require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID) @@ -1523,7 +1524,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { SSHEnabled: false, } - _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) + _, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false) require.Error(t, err) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key) @@ -1658,7 +1659,7 @@ func Test_LoginPeer(t *testing.T) { if sk.AllowExtraDNSLabels { currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels } - _, _, _, err = am.AddPeer(context.Background(), tc.setupKey, "", currentPeer) + _, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false) require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey) loginInput := types.PeerLogin{ @@ -1797,10 +1798,10 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1918,11 +1919,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1982,11 +1983,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer5, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2037,11 +2038,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{ + peer6, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2208,7 +2209,7 @@ func Test_AddPeer(t *testing.T) { <-start - _, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false) if err != nil { errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) return @@ -2416,7 +2417,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false) require.Error(t, err) assert.Contains(t, err.Error(), "user pending approval cannot add peers") } @@ -2451,7 +2452,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false) require.NoError(t, err, "Regular user should be able to add peers") } @@ -2494,7 +2495,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false) require.NoError(t, err) // Now set the user back to pending approval after peer was created @@ -2550,7 +2551,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false) require.NoError(t, err) // Try to login with regular user diff --git a/management/server/peers/ephemeral/interface.go b/management/server/peers/ephemeral/interface.go new file mode 100644 index 000000000..a1605b3b9 --- /dev/null +++ b/management/server/peers/ephemeral/interface.go @@ -0,0 +1,14 @@ +package ephemeral + +import ( + "context" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +type Manager interface { + LoadInitialPeers(ctx context.Context) + Stop() + OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) + OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) +} diff --git a/management/server/ephemeral.go b/management/server/peers/ephemeral/manager/ephemeral.go similarity index 99% rename from management/server/ephemeral.go rename to management/server/peers/ephemeral/manager/ephemeral.go index e3cb5459a..062ba69d2 100644 --- a/management/server/ephemeral.go +++ b/management/server/peers/ephemeral/manager/ephemeral.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" diff --git a/management/server/ephemeral_test.go b/management/server/peers/ephemeral/manager/ephemeral_test.go similarity index 75% rename from management/server/ephemeral_test.go rename to management/server/peers/ephemeral/manager/ephemeral_test.go index d07b9a422..fc7525c29 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/peers/ephemeral/manager/ephemeral_test.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" @@ -7,12 +7,15 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" ) type MockStore struct { @@ -223,3 +226,57 @@ func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) store.account.Peers[p.ID] = p } } + +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[route.ID]*route.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, + }, + Onboarding: types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, + } + + if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} diff --git a/management/server/policy.go b/management/server/policy.go index 312fd53b2..9e4b3f73a 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -151,6 +151,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return false, nil } + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err @@ -161,16 +167,34 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a } } + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } // validatePolicy validates the policy and its rules. func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return err } + + // TODO: Refactor to support multiple rules per policy + existingRuleIDs := make(map[string]bool) + for _, rule := range existingPolicy.Rules { + existingRuleIDs[rule.ID] = true + } + + for _, rule := range policy.Rules { + if rule.ID != "" && !existingRuleIDs[rule.ID] { + return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + } + } } else { policy.ID = xid.New().String() policy.AccountID = accountID diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 027938320..382d026c8 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2037,6 +2037,25 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) }) } +func (s *SqlStore) GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, resourceID string) ([]*types.PolicyRule, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var policyRules []*types.PolicyRule + resourceIDPattern := `%"ID":"` + resourceID + `"%` + result := tx.Where("source_resource LIKE ? OR destination_resource LIKE ?", resourceIDPattern, resourceIDPattern). + Find(&policyRules) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get policy rules for resource id from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policy rules for resource id from store") + } + + return policyRules, nil +} + // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 3c9d896b0..21b660d96 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -202,6 +202,7 @@ type Store interface { IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error + GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) } const ( diff --git a/management/server/types/account.go b/management/server/types/account.go index 9ac2568a0..f830023c7 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -300,9 +300,12 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) } dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) @@ -998,8 +1001,20 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P continue } - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = a.getPeerFromResource(rule.SourceResource, peer.ID) + } else { + sourcePeers, peerInSources = a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = a.getPeerFromResource(rule.DestinationResource, peer.ID) + } else { + destinationPeers, peerInDestinations = a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + } if rule.Bidirectional { if peerInSources { @@ -1121,6 +1136,15 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe return filteredPeers, peerInGroups } +func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + peer := a.GetPeer(resource.ID) + if peer == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peer}, resource.ID == peerID +} + // validatePostureChecksOnPeer validates the posture checks on a peer func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { peer, ok := a.Peers[peerID] @@ -1376,7 +1400,12 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { - peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } if addSourcePeers { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { allSourcePeers[pID] = struct{}{} @@ -1651,3 +1680,24 @@ func peerSupportsPortRanges(peerVer string) bool { meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) return err == nil && meetMinVer } + +// filterZoneRecordsForPeers filters DNS records to only include peers to connect. +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { + filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) + peerIPs := make(map[string]struct{}) + + // Add peer's own IP to include its own DNS records + peerIPs[peer.IP.String()] = struct{}{} + + for _, peerToConnect := range peersToConnect { + peerIPs[peerToConnect.IP.String()] = struct{}{} + } + + for _, record := range customZone.Records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f8ab1d627..cd221b590 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -2,14 +2,17 @@ package types import ( "context" + "fmt" "net" "net/netip" "slices" "testing" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -835,3 +838,109 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") } + +func Test_FilterZoneRecordsForPeers(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + customZone nbdns.CustomZone + peersToConnect []*nbpeer.Peer + expectedRecords []nbdns.SimpleRecord + }{ + { + name: "empty peers to connect", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + { + name: "multiple peers multiple records match", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for i := 1; i <= 100; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + peersToConnect: func() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + peers = append(peers, &nbpeer.Peer{ + ID: fmt.Sprintf("peer%d", i), + IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + }) + } + return peers + }(), + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + { + name: "peers with multiple DNS labels", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, + {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + assert.Equal(t, len(tt.expectedRecords), len(result)) + assert.ElementsMatch(t, tt.expectedRecords, result) + }) + } +} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 17964ed1f..5e86a87c6 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,5 +1,12 @@ package types +import ( + "errors" + "fmt" + "strconv" + "strings" +) + const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -134,3 +141,83 @@ func (p *Policy) SourceGroups() []string { return groupIDs } + +func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) { + rule = strings.TrimSpace(strings.ToLower(rule)) + if rule == "all" { + return PolicyRuleProtocolALL, RulePortRange{}, nil + } + if rule == "icmp" { + return PolicyRuleProtocolICMP, RulePortRange{}, nil + } + + split := strings.Split(rule, "/") + if len(split) != 2 { + return "", RulePortRange{}, errors.New("invalid rule format: expected protocol/port or protocol/port-range") + } + + protoStr := strings.TrimSpace(split[0]) + portStr := strings.TrimSpace(split[1]) + + var protocol PolicyRuleProtocolType + switch protoStr { + case "tcp": + protocol = PolicyRuleProtocolTCP + case "udp": + protocol = PolicyRuleProtocolUDP + case "icmp": + return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") + default: + return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) + } + + portRange, err := parsePortRange(portStr) + if err != nil { + return "", RulePortRange{}, err + } + + return protocol, portRange, nil +} + +func parsePortRange(portStr string) (RulePortRange, error) { + if strings.Contains(portStr, "-") { + rangeParts := strings.Split(portStr, "-") + if len(rangeParts) != 2 { + return RulePortRange{}, fmt.Errorf("invalid port range %q", portStr) + } + start, err := parsePort(strings.TrimSpace(rangeParts[0])) + if err != nil { + return RulePortRange{}, err + } + end, err := parsePort(strings.TrimSpace(rangeParts[1])) + if err != nil { + return RulePortRange{}, err + } + if start > end { + return RulePortRange{}, fmt.Errorf("invalid port range: start %d > end %d", start, end) + } + return RulePortRange{Start: uint16(start), End: uint16(end)}, nil + } + + p, err := parsePort(portStr) + if err != nil { + return RulePortRange{}, err + } + + return RulePortRange{Start: uint16(p), End: uint16(p)}, nil +} + +func parsePort(portStr string) (int, error) { + + if portStr == "" { + return 0, errors.New("empty port") + } + p, err := strconv.Atoi(portStr) + if err != nil { + return 0, fmt.Errorf("invalid port %q: %w", portStr, err) + } + if p < 1 || p > 65535 { + return 0, fmt.Errorf("port out of range (1–65535): %d", p) + } + return p, nil +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go index 84d8e4b88..8347d8c03 100644 --- a/management/server/types/resource.go +++ b/management/server/types/resource.go @@ -4,9 +4,18 @@ import ( "github.com/netbirdio/netbird/shared/management/http/api" ) +type ResourceType string + +const ( + ResourceTypePeer ResourceType = "peer" + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + type Resource struct { ID string - Type string + Type ResourceType } func (r *Resource) ToAPIResponse() *api.Resource { @@ -26,5 +35,5 @@ func (r *Resource) FromAPIRequest(req *api.Resource) { } r.ID = req.Id - r.Type = string(req.Type) + r.Type = ResourceType(req.Type) } diff --git a/management/server/user.go b/management/server/user.go index 3c7c3f433..d40d33c6a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -965,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service + log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.BufferUpdateAccountPeers(ctx, accountID) } diff --git a/management/server/user_test.go b/management/server/user_test.go index 9638559f9..5920a2a33 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1439,10 +1439,10 @@ func TestUserAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer4, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) // updating user with linked peers should update account peers and send peer update diff --git a/release_files/install.sh b/release_files/install.sh index 856d332cb..5d5349ec4 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -130,36 +130,6 @@ repo_gpgcheck=1 EOF } -install_aur_package() { - INSTALL_PKGS="git base-devel go" - REMOVE_PKGS="" - - # Check if dependencies are installed - for PKG in $INSTALL_PKGS; do - if ! pacman -Q "$PKG" > /dev/null 2>&1; then - # Install missing package(s) - ${SUDO} pacman -S "$PKG" --noconfirm - - # Add installed package for clean up later - REMOVE_PKGS="$REMOVE_PKGS $PKG" - fi - done - - # Build package from AUR - cd /tmp && git clone https://aur.archlinux.org/netbird.git - cd netbird && makepkg -sri --noconfirm - - if ! $SKIP_UI_APP; then - cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git - cd netbird-ui && makepkg -sri --noconfirm - fi - - if [ -n "$REMOVE_PKGS" ]; then - # Clean up the installed packages - ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm - fi -} - prepare_tun_module() { # Create the necessary file structure for /dev/net/tun if [ ! -c /dev/net/tun ]; then @@ -276,12 +246,9 @@ install_netbird() { if ! $SKIP_UI_APP; then ${SUDO} rpm-ostree -y install netbird-ui fi - ;; - pacman) - ${SUDO} pacman -Syy - install_aur_package - # in-line with the docs at https://wiki.archlinux.org/title/Netbird - ${SUDO} systemctl enable --now netbird@main.service + # ensure the service is started after install + ${SUDO} netbird service install || true + ${SUDO} netbird service start || true ;; pkg) # Check if the package is already installed @@ -458,11 +425,7 @@ if type uname >/dev/null 2>&1; then elif [ -x "$(command -v yum)" ]; then PACKAGE_MANAGER="yum" echo "The installation will be performed using yum package manager" - elif [ -x "$(command -v pacman)" ]; then - PACKAGE_MANAGER="pacman" - echo "The installation will be performed using pacman package manager" fi - else echo "Unable to determine OS type from /etc/os-release" exit 1 diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index becc10ded..d4a9f1823 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/server/config" @@ -27,6 +28,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -117,7 +119,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { groupsManager := groups.NewManagerMock() secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index dc26253e9..076f2532b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -17,11 +17,12 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" + "github.com/netbirdio/netbird/util/wsproxy" ) const ConnectTimeout = 10 * time.Second @@ -52,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index fbd0e184d..f4ad59052 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -511,6 +511,48 @@ components: - serial_number - extra_dns_labels - ephemeral + PeerTemporaryAccessRequest: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + wg_pub_key: + description: Peer's WireGuard public key + type: string + example: "n0r3pL4c3h0ld3rK3y==" + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - wg_pub_key + - rules + PeerTemporaryAccessResponse: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + id: + description: Peer ID + type: string + example: chacbco6lnnbn6cg5s90 + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - id + - rules AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -1408,7 +1450,8 @@ components: allOf: - $ref: '#/components/schemas/NetworkResourceType' - type: string - example: host + enum: ["peer"] + example: peer NetworkRequest: type: object properties: @@ -2797,6 +2840,42 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/peers/{peerId}/temporary-access: + post: + summary: Create a Temporary Access Peer + description: Creates a temporary access peer that can be used to access this peer and this peer only. The temporary access peer and its access policies will be automatically deleted after it disconnects. + tags: [ Peers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: peerId + required: true + schema: + type: string + description: The unique identifier of a peer + requestBody: + description: Temporary Access Peer create request + content: + 'application/json': + schema: + $ref: '#/components/schemas/PeerTemporaryAccessRequest' + responses: + '200': + description: Temporary Access Peer response + content: + application/json: + schema: + $ref: '#/components/schemas/PeerTemporaryAccessResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers/{peerId}/ingress/ports: get: x-cloud-only: true diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3af9d6ee2..f25603a00 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -168,6 +168,7 @@ const ( const ( ResourceTypeDomain ResourceType = "domain" ResourceTypeHost ResourceType = "host" + ResourceTypePeer ResourceType = "peer" ResourceTypeSubnet ResourceType = "subnet" ) @@ -1224,6 +1225,30 @@ type PeerRequest struct { SshEnabled bool `json:"ssh_enabled"` } +// PeerTemporaryAccessRequest defines model for PeerTemporaryAccessRequest. +type PeerTemporaryAccessRequest struct { + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` + + // WgPubKey Peer's WireGuard public key + WgPubKey string `json:"wg_pub_key"` +} + +// PeerTemporaryAccessResponse defines model for PeerTemporaryAccessResponse. +type PeerTemporaryAccessResponse struct { + // Id Peer ID + Id string `json:"id"` + + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` +} + // PersonalAccessToken defines model for PersonalAccessToken. type PersonalAccessToken struct { // CreatedAt Date the token was created @@ -1952,6 +1977,9 @@ type PostApiPeersPeerIdIngressPortsJSONRequestBody = IngressPortAllocationReques // PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType. type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest +// PostApiPeersPeerIdTemporaryAccessJSONRequestBody defines body for PostApiPeersPeerIdTemporaryAccess for application/json ContentType. +type PostApiPeersPeerIdTemporaryAccessJSONRequestBody = PeerTemporaryAccessRequest + // PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType. type PostApiPoliciesJSONRequestBody = PolicyUpdate diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 234af2395..7ff863133 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1466,8 +1466,8 @@ type FlowConfig struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` - TokenPayload string `protobuf:"bytes,2,opt,name=tokenPayload,proto3" json:"tokenPayload,omitempty"` + Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` + TokenPayload string `protobuf:"bytes,2,opt,name=tokenPayload,proto3" json:"tokenPayload,omitempty"` TokenSignature string `protobuf:"bytes,3,opt,name=tokenSignature,proto3" json:"tokenSignature,omitempty"` Interval *durationpb.Duration `protobuf:"bytes,4,opt,name=interval,proto3" json:"interval,omitempty"` Enabled bool `protobuf:"varint,5,opt,name=enabled,proto3" json:"enabled,omitempty"` @@ -2556,6 +2556,7 @@ type DNSConfig struct { ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"` NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"` CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"` + ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` } func (x *DNSConfig) Reset() { @@ -2611,6 +2612,13 @@ func (x *DNSConfig) GetCustomZones() []*CustomZone { return nil } +func (x *DNSConfig) GetForwarderPort() int64 { + if x != nil { + return x.ForwarderPort + } + return 0 +} + // CustomZone represents a dns.CustomZone type CustomZone struct { state protoimpl.MessageState @@ -3795,7 +3803,7 @@ var file_management_proto_rawDesc = []byte{ 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, - 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xb4, 0x01, 0x0a, 0x09, + 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, @@ -3807,157 +3815,159 @@ var file_management_proto_rawDesc = []byte{ 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, - 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, - 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, - 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, - 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, - 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, - 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, - 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, - 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, - 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, - 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, - 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, - 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, - 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, - 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, - 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, - 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, - 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, - 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, - 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, - 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, - 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, - 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, - 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, - 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, - 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, - 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, - 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, - 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, - 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, - 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, - 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, - 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, - 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, - 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, - 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, - 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, - 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, - 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, - 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, - 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, - 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, - 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, - 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, - 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, - 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, - 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, - 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, - 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, - 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, - 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, - 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, - 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, - 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, - 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x65, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, + 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, + 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, + 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, + 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, + 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, + 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, + 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, + 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, + 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, + 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, + 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, + 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, + 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, + 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, + 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, + 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, + 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, + 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, + 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, + 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, + 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, + 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, + 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, + 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, + 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, + 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, + 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, + 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, + 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, + 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, + 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, + 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, + 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, + 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, + 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, + 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, + 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, + 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, + 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, - 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, - 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, + 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, + 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, - 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, + 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, - 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, - 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, + 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 7403ab13c..5a86e4beb 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -419,6 +419,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; + int64 ForwarderPort = 4; } // CustomZone represents a dns.CustomZone diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index 5dabc5742..57a98614d 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -9,11 +9,8 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client/dialer" - "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" - "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" "github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/messages" ) @@ -296,14 +293,7 @@ func (c *Client) Close() error { } func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { - // Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues - var dialers []dialer.DialeFn - if c.mtu > 0 && c.mtu > iface.DefaultMTU { - c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) - dialers = []dialer.DialeFn{ws.Dialer{}} - } else { - dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} - } + dialers := c.getDialers() rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) conn, err := rd.Dial() diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index b496f6a9b..967e18d79 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" quictls "github.com/netbirdio/netbird/shared/relay/tls" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/conn.go b/shared/relay/client/dialer/ws/conn.go index 0086b702b..d5b719f51 100644 --- a/shared/relay/client/dialer/ws/conn.go +++ b/shared/relay/client/dialer/ws/conn.go @@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Write(b []byte) (n int, err error) { - err = c.Conn.Write(c.ctx, websocket.MessageBinary, b) - return 0, err + return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b) } func (c *Conn) RemoteAddr() net.Addr { diff --git a/shared/relay/client/dialer/ws/dialopts_generic.go b/shared/relay/client/dialer/ws/dialopts_generic.go new file mode 100644 index 000000000..9dfe698d0 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_generic.go @@ -0,0 +1,11 @@ +//go:build !js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + return &websocket.DialOptions{ + HTTPClient: httpClientNbDialer(), + } +} diff --git a/shared/relay/client/dialer/ws/dialopts_js.go b/shared/relay/client/dialer/ws/dialopts_js.go new file mode 100644 index 000000000..7eac27531 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_js.go @@ -0,0 +1,10 @@ +//go:build js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + // WASM version doesn't support HTTPClient + return &websocket.DialOptions{} +} diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 109651f5d..66fff3447 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { @@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } - opts := &websocket.DialOptions{ - HTTPClient: httpClientNbDialer(), - } + opts := createDialOptions() parsedURL, err := url.Parse(wsURL) if err != nil { diff --git a/shared/relay/client/dialers_generic.go b/shared/relay/client/dialers_generic.go new file mode 100644 index 000000000..a8ed79961 --- /dev/null +++ b/shared/relay/client/dialers_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package client + +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +// getDialers returns the list of dialers to use for connecting to the relay server. +func (c *Client) getDialers() []dialer.DialeFn { + if c.mtu > 0 && c.mtu > iface.DefaultMTU { + c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) + return []dialer.DialeFn{ws.Dialer{}} + } + return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} +} diff --git a/shared/relay/client/dialers_js.go b/shared/relay/client/dialers_js.go new file mode 100644 index 000000000..6bd0e6696 --- /dev/null +++ b/shared/relay/client/dialers_js.go @@ -0,0 +1,13 @@ +//go:build js + +package client + +import ( + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +func (c *Client) getDialers() []dialer.DialeFn { + // JS/WASM build only uses WebSocket transport + return []dialer.DialeFn{ws.Dialer{}} +} diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index a40343fb1..6220e7f6b 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -78,9 +78,10 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin tokenStore: tokenStore, mtu: mtu, serverPicker: &ServerPicker{ - TokenStore: tokenStore, - PeerID: peerID, - MTU: mtu, + TokenStore: tokenStore, + PeerID: peerID, + MTU: mtu, + ConnectionTimeout: defaultConnectionTimeout, }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), diff --git a/shared/relay/client/picker.go b/shared/relay/client/picker.go index b6c7b5e8a..39d0ba072 100644 --- a/shared/relay/client/picker.go +++ b/shared/relay/client/picker.go @@ -13,11 +13,8 @@ import ( ) const ( - maxConcurrentServers = 7 -) - -var ( - connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 + defaultConnectionTimeout = 30 * time.Second ) type connResult struct { @@ -27,14 +24,15 @@ type connResult struct { } type ServerPicker struct { - TokenStore *auth.TokenStore - ServerURLs atomic.Value - PeerID string - MTU uint16 + TokenStore *auth.TokenStore + ServerURLs atomic.Value + PeerID string + MTU uint16 + ConnectionTimeout time.Duration } func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { - ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout) defer cancel() totalServers := len(sp.ServerURLs.Load().([]string)) diff --git a/shared/relay/client/picker_test.go b/shared/relay/client/picker_test.go index 28167c5ce..fb3fa7375 100644 --- a/shared/relay/client/picker_test.go +++ b/shared/relay/client/picker_test.go @@ -8,15 +8,15 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { - connectionTimeout = 5 * time.Second - + timeout := 5 * time.Second sp := ServerPicker{ - TokenStore: nil, - PeerID: "test", + TokenStore: nil, + PeerID: "test", + ConnectionTimeout: timeout, } sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) - ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) + ctx, cancel := context.WithTimeout(context.Background(), timeout+1) defer cancel() go func() { diff --git a/shared/relay/healthcheck/env.go b/shared/relay/healthcheck/env.go new file mode 100644 index 000000000..2b584c195 --- /dev/null +++ b/shared/relay/healthcheck/env.go @@ -0,0 +1,24 @@ +package healthcheck + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" +) + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/shared/relay/healthcheck/env_test.go b/shared/relay/healthcheck/env_test.go new file mode 100644 index 000000000..2e14bb8bf --- /dev/null +++ b/shared/relay/healthcheck/env_test.go @@ -0,0 +1,36 @@ +package healthcheck + +import ( + "os" + "testing" +) + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/shared/relay/healthcheck/receiver.go b/shared/relay/healthcheck/receiver.go index b3503d5db..90f795bbe 100644 --- a/shared/relay/healthcheck/receiver.go +++ b/shared/relay/healthcheck/receiver.go @@ -7,10 +7,15 @@ import ( log "github.com/sirupsen/logrus" ) -var ( - heartbeatTimeout = healthCheckInterval + 10*time.Second +const ( + defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second ) +type ReceiverOptions struct { + HeartbeatTimeout time.Duration + AttemptThreshold int +} + // Receiver is a healthcheck receiver // It will listen for heartbeat and check if the heartbeat is not received in a certain time // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work @@ -27,6 +32,23 @@ type Receiver struct { // NewReceiver creates a new healthcheck receiver and start the timer in the background func NewReceiver(log *log.Entry) *Receiver { + opts := ReceiverOptions{ + HeartbeatTimeout: defaultHeartbeatTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewReceiverWithOpts(log, opts) +} + +func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver { + heartbeatTimeout := opts.HeartbeatTimeout + if heartbeatTimeout <= 0 { + heartbeatTimeout = defaultHeartbeatTimeout + } + attemptThreshold := opts.AttemptThreshold + if attemptThreshold <= 0 { + attemptThreshold = defaultAttemptThreshold + } + ctx, ctxCancel := context.WithCancel(context.Background()) r := &Receiver{ @@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver { ctx: ctx, ctxCancel: ctxCancel, heartbeat: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + attemptThreshold: attemptThreshold, } - go r.waitForHealthcheck() + go r.waitForHealthcheck(heartbeatTimeout) return r } @@ -55,7 +77,7 @@ func (r *Receiver) Stop() { r.ctxCancel() } -func (r *Receiver) waitForHealthcheck() { +func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) { ticker := time.NewTicker(heartbeatTimeout) defer ticker.Stop() defer r.ctxCancel() diff --git a/shared/relay/healthcheck/receiver_test.go b/shared/relay/healthcheck/receiver_test.go index 2794159f6..b20cc5124 100644 --- a/shared/relay/healthcheck/receiver_test.go +++ b/shared/relay/healthcheck/receiver_test.go @@ -2,31 +2,18 @@ package healthcheck import ( "context" - "fmt" - "os" - "sync" "testing" "time" log "github.com/sirupsen/logrus" ) -// Mutex to protect global variable access in tests -var testMutex sync.Mutex - func TestNewReceiver(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 5 * time.Second - testMutex.Unlock() - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 5 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) { } func TestNewReceiverNotReceive(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 1 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 1 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) { } func TestNewReceiverAck(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 2 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 2 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() r.Heartbeat() @@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - testMutex.Lock() - originalInterval := healthCheckInterval - originalTimeout := heartbeatTimeout - healthCheckInterval = 1 * time.Second - heartbeatTimeout = healthCheckInterval + 500*time.Millisecond - testMutex.Unlock() + healthCheckInterval := 1 * time.Second - defer func() { - testMutex.Lock() - healthCheckInterval = originalInterval - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := ReceiverOptions{ + HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond, + AttemptThreshold: tc.threshold, + } - receiver := NewReceiver(log.WithField("test_name", tc.name)) + receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts) - testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval if tc.resetCounterOnce { receiver.Heartbeat() - t.Logf("reset counter once") } select { @@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { } t.Fatalf("should have timed out before %s", testTimeout) } - }) } } diff --git a/shared/relay/healthcheck/sender.go b/shared/relay/healthcheck/sender.go index 57b3015ec..771e94206 100644 --- a/shared/relay/healthcheck/sender.go +++ b/shared/relay/healthcheck/sender.go @@ -2,52 +2,76 @@ package healthcheck import ( "context" - "os" - "strconv" "time" log "github.com/sirupsen/logrus" ) const ( - defaultAttemptThreshold = 1 - defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" + defaultAttemptThreshold = 1 + + defaultHealthCheckInterval = 25 * time.Second + defaultHealthCheckTimeout = 20 * time.Second ) -var ( - healthCheckInterval = 25 * time.Second - healthCheckTimeout = 20 * time.Second -) +type SenderOptions struct { + HealthCheckInterval time.Duration + HealthCheckTimeout time.Duration + AttemptThreshold int +} // Sender is a healthcheck sender // It will send healthcheck signal to the receiver // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // It will also stop if the context is canceled type Sender struct { - log *log.Entry // HealthCheck is a channel to send health check signal to the peer HealthCheck chan struct{} // Timeout is a channel to the health check signal is not received in a certain time Timeout chan struct{} + log *log.Entry + healthCheckInterval time.Duration + timeout time.Duration + ack chan struct{} alive bool attemptThreshold int } -// NewSender creates a new healthcheck sender -func NewSender(log *log.Entry) *Sender { +func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender { + if opts.HealthCheckInterval <= 0 { + opts.HealthCheckInterval = defaultHealthCheckInterval + } + if opts.HealthCheckTimeout <= 0 { + opts.HealthCheckTimeout = defaultHealthCheckTimeout + } + if opts.AttemptThreshold <= 0 { + opts.AttemptThreshold = defaultAttemptThreshold + } hc := &Sender{ - log: log, - HealthCheck: make(chan struct{}, 1), - Timeout: make(chan struct{}, 1), - ack: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + log: log, + healthCheckInterval: opts.HealthCheckInterval, + timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout, + ack: make(chan struct{}, 1), + attemptThreshold: opts.AttemptThreshold, } return hc } +// NewSender creates a new healthcheck sender +func NewSender(log *log.Entry) *Sender { + opts := SenderOptions{ + HealthCheckInterval: defaultHealthCheckInterval, + HealthCheckTimeout: defaultHealthCheckTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewSenderWithOpts(log, opts) +} + // OnHCResponse sends an acknowledgment signal to the sender func (hc *Sender) OnHCResponse() { select { @@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() { } func (hc *Sender) StartHealthCheck(ctx context.Context) { - ticker := time.NewTicker(healthCheckInterval) + ticker := time.NewTicker(hc.healthCheckInterval) defer ticker.Stop() - timeoutTicker := time.NewTicker(hc.getTimeoutTime()) + timeoutTicker := time.NewTicker(hc.timeout) defer timeoutTicker.Stop() defer close(hc.HealthCheck) @@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) { } } } - -func (hc *Sender) getTimeoutTime() time.Duration { - return healthCheckInterval + healthCheckTimeout -} - -func getAttemptThresholdFromEnv() int { - if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { - threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) - if err != nil { - log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) - return defaultAttemptThreshold - } - return int(threshold) - } - return defaultAttemptThreshold -} diff --git a/shared/relay/healthcheck/sender_test.go b/shared/relay/healthcheck/sender_test.go index 23446366a..122fe0f16 100644 --- a/shared/relay/healthcheck/sender_test.go +++ b/shared/relay/healthcheck/sender_test.go @@ -2,26 +2,23 @@ package healthcheck import ( "context" - "fmt" - "os" "testing" "time" log "github.com/sirupsen/logrus" ) -func TestMain(m *testing.M) { - // override the health check interval to speed up the test - healthCheckInterval = 2 * time.Second - healthCheckTimeout = 100 * time.Millisecond - code := m.Run() - os.Exit(code) -} +var ( + testOpts = SenderOptions{ + HealthCheckInterval: 2 * time.Second, + HealthCheckTimeout: 100 * time.Millisecond, + } +) func TestNewHealthPeriod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) { func TestNewHealthFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) select { case <-hc.Timeout: - case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond): t.Fatalf("health check is not timed out") } } func TestNewHealthcheckStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) time.Sleep(100 * time.Millisecond) @@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) { func TestTimeoutReset(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - originalInterval := healthCheckInterval - originalTimeout := healthCheckTimeout - healthCheckInterval = 1 * time.Second - healthCheckTimeout = 500 * time.Millisecond - - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := SenderOptions{ + HealthCheckInterval: 1 * time.Second, + HealthCheckTimeout: 500 * time.Millisecond, + AttemptThreshold: tc.threshold, + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sender := NewSender(log.WithField("test_name", tc.name)) + sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts) senderExit := make(chan struct{}) go func() { sender.StartHealthCheck(ctx) @@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { } }() - testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval select { case <-sender.Timeout: @@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } - healthCheckInterval = originalInterval - healthCheckTimeout = originalTimeout }) } } - -//nolint:tenv -func TestGetAttemptThresholdFromEnv(t *testing.T) { - tests := []struct { - name string - envValue string - expected int - }{ - {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, - {"Custom attempt threshold when env is set to a valid integer", "3", 3}, - {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.envValue == "" { - os.Unsetenv(defaultAttemptThresholdEnv) - } else { - os.Setenv(defaultAttemptThresholdEnv, tt.envValue) - } - - result := getAttemptThresholdFromEnv() - if result != tt.expected { - t.Fatalf("Expected %d, got %d", tt.expected, result) - } - - os.Unsetenv(defaultAttemptThresholdEnv) - }) - } -} diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 82ab678f4..31f3372c0 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -16,10 +16,11 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" + "github.com/netbirdio/netbird/util/wsproxy" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -57,7 +58,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index d4fedc492..bc2d4d1be 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -22,7 +22,7 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -93,7 +93,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err) } var sockErr error @@ -102,7 +102,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error log.Errorf("Failed to create ipv6 raw socket: %v", err) } else { if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err) } } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 1d76fa4e4..696c44723 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,14 +8,16 @@ import ( "fmt" "net" "net/http" - // nolint:gosec _ "net/http/pprof" - "strings" + "net/netip" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "github.com/netbirdio/netbird/signal/metrics" @@ -23,6 +25,8 @@ import ( "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/server" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" log "github.com/sirupsen/logrus" @@ -32,6 +36,8 @@ import ( "google.golang.org/grpc/keepalive" ) +const legacyGRPCPort = 10000 + var ( signalPort int metricsPort int @@ -113,7 +119,7 @@ var ( } proto.RegisterSignalExchangeServer(grpcServer, srv) - grpcRootHandler := grpcHandlerFunc(grpcServer) + grpcRootHandler := grpcHandlerFunc(grpcServer, metricsServer.Meter) if certManager != nil { startServerWithCertManager(certManager, grpcRootHandler) @@ -123,19 +129,30 @@ var ( var grpcListener net.Listener var httpListener net.Listener - // If certManager is configured and signalPort == 443, then the gRPC server has already been started - if certManager == nil || signalPort != 443 { - grpcListener, err = serveGRPC(grpcServer, signalPort) + // Start the main server - always serve HTTP with WebSocket proxy support + // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager + if certManager == nil { + // Without TLS, serve plain HTTP + httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { return err } - log.Infof("running gRPC server: %s", grpcListener.Addr().String()) + log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) + } else if signalPort != 443 { + // With TLS but not on port 443, serve HTTPS + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + if err != nil { + return err + } + log.Infof("running HTTPS server with WebSocket proxy: %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) } - if signalPort != 10000 { + if signalPort != legacyGRPCPort { // The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal // are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000. - compatListener, err = serveGRPC(grpcServer, 10000) + compatListener, err = serveGRPC(grpcServer, legacyGRPCPort) if err != nil { return err } @@ -236,11 +253,14 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h } } -func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler { +func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto") - if r.ProtoMajor == 2 && grpcHeader { + switch { + case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent: + wsProxy.Handler().ServeHTTP(w, r) + default: grpcServer.ServeHTTP(w, r) } }) @@ -257,7 +277,11 @@ func notifyStop(msg string) { func serveHTTP(httpListener net.Listener, handler http.Handler) { go func() { - err := http.Serve(httpListener, handler) + // Use h2c to support HTTP/2 without TLS (needed for gRPC) + h1s := &http.Server{ + Handler: h2c.NewHandler(handler, &http2.Server{}), + } + err := h1s.Serve(httpListener) if err != nil { notifyStop(fmt.Sprintf("failed running HTTP server %v", err)) } diff --git a/util/net/conn.go b/util/net/conn.go deleted file mode 100644 index 26693f841..000000000 --- a/util/net/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !ios - -package net - -import ( - "net" - - log "github.com/sirupsen/logrus" -) - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} diff --git a/util/net/dial.go b/util/net/dial.go deleted file mode 100644 index 595311492..000000000 --- a/util/net/dial.go +++ /dev/null @@ -1,58 +0,0 @@ -//go:build !ios - -package net - -import ( - "fmt" - "net" - - log "github.com/sirupsen/logrus" -) - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_dial.go b/util/net/dialer_dial.go deleted file mode 100644 index 1659b6220..000000000 --- a/util/net/dialer_dial.go +++ /dev/null @@ -1,107 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. -func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHooks removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - log.Debugf("Dialing %s %s", network, address) - - if CustomRoutingDisabled() { - return d.Dialer.DialContext(ctx, network, address) - } - - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := callDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} diff --git a/util/net/dialer_init_nonlinux.go b/util/net/dialer_init_nonlinux.go deleted file mode 100644 index 8c57ebbaa..000000000 --- a/util/net/dialer_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (d *Dialer) init() { - // implemented on Linux and Android only -} diff --git a/util/net/env_generic.go b/util/net/env_generic.go deleted file mode 100644 index 6d142a838..000000000 --- a/util/net/env_generic.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !linux || android - -package net - -func Init() { - // nothing to do on non-linux -} - -func AdvancedRouting() bool { - // non-linux currently doesn't support advanced routing - return false -} diff --git a/util/net/listen.go b/util/net/listen.go deleted file mode 100644 index 3ae8a9435..000000000 --- a/util/net/listen.go +++ /dev/null @@ -1,37 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" -) - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_init_nonlinux.go b/util/net/listener_init_nonlinux.go deleted file mode 100644 index 80f6f7f1a..000000000 --- a/util/net/listener_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (l *ListenerConfig) init() { - // implemented on Linux and Android only -} diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go deleted file mode 100644 index 4060ab49a..000000000 --- a/util/net/listener_listen.go +++ /dev/null @@ -1,205 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. -type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc - listenerAddressRemoveHooksMutex sync.RWMutex - listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. -func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) -} - -// RemoveListenerHooks removes all listener hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil - - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. -func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - if CustomRoutingDisabled() { - return l.ListenConfig.ListenPacket(ctx, network, address) - } - - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -// RemoveAddress removes an address from the seen cache and triggers removal hooks. -func (c *PacketConn) RemoveAddress(addr string) { - if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { - return - } - - ipStr, _, err := net.SplitHostPort(addr) - if err != nil { - log.Errorf("Error splitting IP address and port: %v", err) - return - } - - ipAddr, err := netip.ParseAddr(ipStr) - if err != nil { - log.Errorf("Error parsing IP address %s: %v", ipStr, err) - return - } - - prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) - - listenerAddressRemoveHooksMutex.RLock() - defer listenerAddressRemoveHooksMutex.RUnlock() - - for _, hook := range listenerAddressRemoveHooks { - if err := hook(c.ID, prefix); err != nil { - log.Errorf("Error executing listener address remove hook: %v", err) - } - } -} - - -// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality -func WrapPacketConn(conn net.PacketConn) *PacketConn { - return &PacketConn{ - PacketConn: conn, - ID: GenerateConnID(), - seenAddrs: &sync.Map{}, - } -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -} diff --git a/util/util_js.go b/util/util_js.go new file mode 100644 index 000000000..8c243cab3 --- /dev/null +++ b/util/util_js.go @@ -0,0 +1,8 @@ +//go:build js + +package util + +// IsAdmin returns false for WASM as there's no admin concept in browser +func IsAdmin() bool { + return false +} diff --git a/util/wsproxy/client/dialer_js.go b/util/wsproxy/client/dialer_js.go new file mode 100644 index 000000000..bd50f51b5 --- /dev/null +++ b/util/wsproxy/client/dialer_js.go @@ -0,0 +1,172 @@ +package client + +import ( + "context" + "fmt" + "net" + "sync" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const dialTimeout = 30 * time.Second + +// websocketConn wraps a JavaScript WebSocket to implement net.Conn +type websocketConn struct { + ws js.Value + remoteAddr string + messages chan []byte + readBuf []byte + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex +} + +func (c *websocketConn) Read(b []byte) (int, error) { + c.mu.Lock() + if len(c.readBuf) > 0 { + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + c.mu.Unlock() + return n, nil + } + c.mu.Unlock() + + select { + case data := <-c.messages: + n := copy(b, data) + if n < len(data) { + c.mu.Lock() + c.readBuf = data[n:] + c.mu.Unlock() + } + return n, nil + case <-c.ctx.Done(): + return 0, c.ctx.Err() + } +} + +func (c *websocketConn) Write(b []byte) (int, error) { + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + } + + uint8Array := js.Global().Get("Uint8Array").New(len(b)) + js.CopyBytesToJS(uint8Array, b) + c.ws.Call("send", uint8Array) + return len(b), nil +} + +func (c *websocketConn) Close() error { + c.cancel() + c.ws.Call("close") + return nil +} + +func (c *websocketConn) LocalAddr() net.Addr { + return nil +} + +func (c *websocketConn) RemoteAddr() net.Addr { + return stringAddr(c.remoteAddr) +} +func (c *websocketConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// stringAddr is a simple net.Addr that returns a string +type stringAddr string + +func (s stringAddr) Network() string { return "tcp" } +func (s stringAddr) String() string { return string(s) } + +// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithWebSocketDialer(tlsEnabled bool, component string) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + scheme := "wss" + if !tlsEnabled { + scheme = "ws" + } + wsURL := fmt.Sprintf("%s://%s%s%s", scheme, addr, wsproxy.ProxyPath, component) + + ws := js.Global().Get("WebSocket").New(wsURL) + + connCtx, connCancel := context.WithCancel(context.Background()) + conn := &websocketConn{ + ws: ws, + remoteAddr: addr, + messages: make(chan []byte, 100), + ctx: connCtx, + cancel: connCancel, + } + + ws.Set("binaryType", "arraybuffer") + + openCh := make(chan struct{}) + errorCh := make(chan error, 1) + + ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any { + close(openCh) + return nil + })) + + ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any { + select { + case errorCh <- wsproxy.ErrConnectionFailed: + default: + } + return nil + })) + + ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any { + event := args[0] + data := event.Get("data") + + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + + select { + case conn.messages <- bytes: + default: + log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr) + } + return nil + })) + + ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any { + conn.cancel() + return nil + })) + + select { + case <-openCh: + return conn, nil + case err := <-errorCh: + return nil, err + case <-ctx.Done(): + ws.Call("close") + return nil, ctx.Err() + case <-time.After(dialTimeout): + ws.Call("close") + return nil, wsproxy.ErrConnectionTimeout + } + }) +} diff --git a/util/wsproxy/constants.go b/util/wsproxy/constants.go new file mode 100644 index 000000000..a31c0fbc8 --- /dev/null +++ b/util/wsproxy/constants.go @@ -0,0 +1,20 @@ +package wsproxy + +import "errors" + +// ProxyPath is the base path where the WebSocket proxy is mounted on servers. +const ProxyPath = "/ws-proxy" + +// Component paths that are appended to ProxyPath +const ( + ManagementComponent = "/management" + SignalComponent = "/signal" + FlowComponent = "/flow" +) + +// Common errors +var ( + ErrConnectionTimeout = errors.New("WebSocket connection timeout") + ErrConnectionFailed = errors.New("WebSocket connection failed") + ErrBackendUnavailable = errors.New("backend unavailable") +) diff --git a/util/wsproxy/server/metrics.go b/util/wsproxy/server/metrics.go new file mode 100644 index 000000000..dd3b96dad --- /dev/null +++ b/util/wsproxy/server/metrics.go @@ -0,0 +1,118 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// MetricsRecorder defines the interface for recording proxy metrics +type MetricsRecorder interface { + // RecordConnection records a new connection + RecordConnection(ctx context.Context) + // RecordDisconnection records a connection closing + RecordDisconnection(ctx context.Context) + // RecordBytesTransferred records bytes transferred in a direction + RecordBytesTransferred(ctx context.Context, direction string, bytes int64) + // RecordError records an error + RecordError(ctx context.Context, errorType string) +} + +// NoOpMetricsRecorder is a no-op implementation that does nothing +type NoOpMetricsRecorder struct{} + +func (n NoOpMetricsRecorder) RecordConnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordDisconnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + // no-op +} +func (n NoOpMetricsRecorder) RecordError(ctx context.Context, errorType string) { + // no-op +} + +// Recorder implements MetricsRecorder using OpenTelemetry +type Recorder struct { + activeConnections metric.Int64UpDownCounter + bytesTransferred metric.Int64Counter + errors metric.Int64Counter +} + +// NewMetricsRecorder creates a new OpenTelemetry-based metrics recorder +func NewMetricsRecorder(meter metric.Meter) (*Recorder, error) { + activeConnections, err := meter.Int64UpDownCounter( + "wsproxy_active_connections", + metric.WithDescription("Number of active WebSocket proxy connections"), + ) + if err != nil { + return nil, err + } + + bytesTransferred, err := meter.Int64Counter( + "wsproxy_bytes_transferred_total", + metric.WithDescription("Total bytes transferred through the proxy"), + ) + if err != nil { + return nil, err + } + + errors, err := meter.Int64Counter( + "wsproxy_errors_total", + metric.WithDescription("Total number of proxy errors"), + ) + if err != nil { + return nil, err + } + + return &Recorder{ + activeConnections: activeConnections, + bytesTransferred: bytesTransferred, + errors: errors, + }, nil +} + +func (o *Recorder) RecordConnection(ctx context.Context) { + o.activeConnections.Add(ctx, 1) +} + +func (o *Recorder) RecordDisconnection(ctx context.Context) { + o.activeConnections.Add(ctx, -1) +} + +func (o *Recorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + o.bytesTransferred.Add(ctx, bytes, metric.WithAttributes( + attribute.String("direction", direction), + )) +} + +func (o *Recorder) RecordError(ctx context.Context, errorType string) { + o.errors.Add(ctx, 1, metric.WithAttributes( + attribute.String("error_type", errorType), + )) +} + +// Option defines functional options for the Proxy +type Option func(*Config) + +// WithMetrics sets a custom metrics recorder +func WithMetrics(recorder MetricsRecorder) Option { + return func(c *Config) { + c.MetricsRecorder = recorder + } +} + +// WithOTelMeter creates and sets an OpenTelemetry metrics recorder +func WithOTelMeter(meter metric.Meter) Option { + return func(c *Config) { + if recorder, err := NewMetricsRecorder(meter); err == nil { + c.MetricsRecorder = recorder + } else { + log.Warnf("Failed to create OTel metrics recorder: %v", err) + } + } +} diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go new file mode 100644 index 000000000..977440a60 --- /dev/null +++ b/util/wsproxy/server/proxy.go @@ -0,0 +1,227 @@ +package server + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/netip" + "sync" + "time" + + "github.com/coder/websocket" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const ( + dialTimeout = 10 * time.Second + bufferSize = 32 * 1024 +) + +// Config contains the configuration for the WebSocket proxy. +type Config struct { + LocalGRPCAddr netip.AddrPort + Path string + MetricsRecorder MetricsRecorder +} + +// Proxy handles WebSocket to TCP proxying for gRPC connections. +type Proxy struct { + config Config + metrics MetricsRecorder +} + +// New creates a new WebSocket proxy instance with optional configuration +func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { + config := Config{ + LocalGRPCAddr: localGRPCAddr, + Path: wsproxy.ProxyPath, + MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op + } + + for _, opt := range opts { + opt(&config) + } + + return &Proxy{ + config: config, + metrics: config.MetricsRecorder, + } +} + +// Handler returns an http.Handler that proxies WebSocket connections to the local gRPC server. +func (p *Proxy) Handler() http.Handler { + return http.HandlerFunc(p.handleWebSocket) +} + +func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + p.metrics.RecordConnection(ctx) + defer p.metrics.RecordDisconnection(ctx) + + log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + } + + wsConn, err := websocket.Accept(w, r, acceptOptions) + if err != nil { + p.metrics.RecordError(ctx, "websocket_accept_failed") + log.Errorf("WebSocket upgrade failed from %s: %v", r.RemoteAddr, err) + return + } + defer func() { + if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { + log.Debugf("Failed to close WebSocket: %v", err) + } + }() + + log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) + tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) + if err != nil { + p.metrics.RecordError(ctx, "tcp_dial_failed") + log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err) + if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil { + log.Debugf("Failed to close WebSocket after connection failure: %v", err) + } + return + } + defer func() { + if err := tcpConn.Close(); err != nil { + log.Debugf("Failed to close TCP connection: %v", err) + } + }() + + log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) + + p.proxyData(ctx, wsConn, tcpConn) +} + +func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { + proxyCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + wg.Add(2) + + go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) + go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + log.Tracef("Proxy data transfer completed, both goroutines terminated") + case <-proxyCtx.Done(): + log.Tracef("Proxy data transfer cancelled, forcing connection closure") + + if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil { + log.Tracef("Error closing WebSocket during cancellation: %v", err) + } + if err := tcpConn.Close(); err != nil { + log.Tracef("Error closing TCP connection during cancellation: %v", err) + } + + select { + case <-done: + log.Tracef("Goroutines terminated after forced connection closure") + case <-time.After(2 * time.Second): + log.Tracef("Goroutines did not terminate within timeout after connection closure") + } + } +} + +func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + for { + msgType, data, err := wsConn.Read(ctx) + if err != nil { + switch { + case ctx.Err() != nil: + log.Debugf("wsToTCP goroutine terminating due to context cancellation") + case websocket.CloseStatus(err) == websocket.StatusNormalClosure: + log.Debugf("WebSocket closed normally") + default: + p.metrics.RecordError(ctx, "websocket_read_error") + log.Errorf("WebSocket read error: %v", err) + } + return + } + + if msgType != websocket.MessageBinary { + log.Warnf("Unexpected WebSocket message type: %v", msgType) + continue + } + + if ctx.Err() != nil { + log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") + return + } + + if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP write deadline: %v", err) + } + + n, err := tcpConn.Write(data) + if err != nil { + p.metrics.RecordError(ctx, "tcp_write_error") + log.Errorf("TCP write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) + } +} + +func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + buf := make([]byte, bufferSize) + for { + if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP read deadline: %v", err) + } + n, err := tcpConn.Read(buf) + + if err != nil { + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation") + return + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + continue + } + + if err != io.EOF { + log.Errorf("TCP read error: %v", err) + } + return + } + + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") + return + } + + if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { + p.metrics.RecordError(ctx, "websocket_write_error") + log.Errorf("WebSocket write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) + } +}