diff --git a/.github/workflows/docs-ack.yml b/.github/workflows/docs-ack.yml index 9116be8c7..f11142a36 100644 --- a/.github/workflows/docs-ack.yml +++ b/.github/workflows/docs-ack.yml @@ -16,19 +16,29 @@ jobs: steps: - name: Read PR body id: body + shell: bash run: | - BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH") - echo "body<> $GITHUB_OUTPUT - echo "$BODY" >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT + set -euo pipefail + BODY_B64=$(jq -r '.pull_request.body // "" | @base64' "$GITHUB_EVENT_PATH") + { + echo "body_b64=$BODY_B64" + } >> "$GITHUB_OUTPUT" - name: Validate checkbox selection id: validate + shell: bash + env: + BODY_B64: ${{ steps.body.outputs.body_b64 }} run: | - body='${{ steps.body.outputs.body }}' + set -euo pipefail + if ! body="$(printf '%s' "$BODY_B64" | base64 -d)"; then + echo "::error::Failed to decode PR body from base64. Data may be corrupted or missing." + exit 1 + fi + + added_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*I added/updated documentation' | wc -l | tr -d '[:space:]' || true) + noneed_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*Documentation is \*\*not needed\*\*' | wc -l | tr -d '[:space:]' || true) - added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ') - noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ') if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then echo "::error::Choose exactly one: either 'docs added' OR 'not needed'." 
@@ -41,30 +51,35 @@ jobs: fi if [ "$added_checked" -eq 1 ]; then - echo "mode=added" >> $GITHUB_OUTPUT + echo "mode=added" >> "$GITHUB_OUTPUT" else - echo "mode=noneed" >> $GITHUB_OUTPUT + echo "mode=noneed" >> "$GITHUB_OUTPUT" fi - name: Extract docs PR URL (when 'docs added') if: steps.validate.outputs.mode == 'added' id: extract + shell: bash + env: + BODY_B64: ${{ steps.body.outputs.body_b64 }} run: | - body='${{ steps.body.outputs.body }}' + set -euo pipefail + body="$(printf '%s' "$BODY_B64" | base64 -d)" # Strictly require HTTPS and that it's a PR in netbirdio/docs - # Examples accepted: - # https://github.com/netbirdio/docs/pull/1234 - url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true) + # e.g., https://github.com/netbirdio/docs/pull/1234 + url="$(printf '%s' "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)" - if [ -z "$url" ]; then + if [ -z "${url:-}" ]; then echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)." 
exit 1 fi - pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#') - echo "url=$url" >> $GITHUB_OUTPUT - echo "pr_number=$pr_number" >> $GITHUB_OUTPUT + pr_number="$(printf '%s' "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')" + { + echo "url=$url" + echo "pr_number=$pr_number" + } >> "$GITHUB_OUTPUT" - name: Verify docs PR exists (and is open or merged) if: steps.validate.outputs.mode == 'added' diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 0013833c4..ba36c013b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -217,7 +217,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "" + raceFlag: "-race" runs-on: ubuntu-22.04 steps: - name: Install Go @@ -382,6 +382,32 @@ jobs: store: [ 'sqlite', 'postgres' ] runs-on: ubuntu-22.04 steps: + - name: Create Docker network + run: docker network create promnet + + - name: Start Prometheus Pushgateway + run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway + + - name: Start Prometheus (for Pushgateway forwarding) + run: | + echo ' + global: + scrape_interval: 15s + scrape_configs: + - job_name: "pushgateway" + static_configs: + - targets: ["pushgateway:9091"] + remote_write: + - url: ${{ secrets.GRAFANA_URL }} + basic_auth: + username: ${{ secrets.GRAFANA_USER }} + password: ${{ secrets.GRAFANA_API_KEY }} + ' > prometheus.yml + + docker run -d --name prometheus --network promnet \ + -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ + -p 9090:9090 \ + prom/prometheus - name: Install Go uses: actions/setup-go@v5 with: @@ -428,9 +454,10 @@ jobs: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \ CI=true \ + GIT_BRANCH=${{ github.ref_name }} \ go test -tags devcert -run=^$ -bench=. \ - -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./management/... ./shared/management/... 
+ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ + -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http) api_benchmark: name: "Management / Benchmark (API)" @@ -521,7 +548,7 @@ jobs: -run=^$ \ -bench=. \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ - -timeout 20m ./management/... ./shared/management/... + -timeout 20m ./management/server/http/... api_integration_test: name: "Management / Integration" @@ -571,4 +598,4 @@ jobs: CI=true \ go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./management/... ./shared/management/... \ No newline at end of file + -timeout 20m ./management/server/http/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index d9ff0a84b..2083c0721 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -63,7 +63,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV + - run: echo "files=$(go list ./... 
| ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV - name: test run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7be52259b..e9741f541 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.22" + SIGN_PIPE_VER: "v0.0.23" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" diff --git a/README.md b/README.md index ea7655869..2c5ee2ab6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +


@@ -52,7 +53,7 @@ ### Open Source Network Security in a Single Platform -centralized-network-management 1 +https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) diff --git a/client/Dockerfile b/client/Dockerfile index e19a09909..b2f627409 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -18,7 +18,7 @@ ENV \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + NB_ENTRYPOINT_LOGIN_TIMEOUT="5" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/android/client.go b/client/android/client.go index 678f5d9d5..218817e62 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,6 +4,7 @@ package android import ( "context" + "os" "slices" "sync" @@ -18,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net" ) // ConnectionListener export internal Listener for mobile @@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi } // Run start the internal client. 
It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } + } +} diff --git a/client/android/env_list.go b/client/android/env_list.go new file mode 100644 index 000000000..04122300a --- /dev/null +++ b/client/android/env_list.go @@ -0,0 +1,32 @@ +package android + +import "github.com/netbirdio/netbird/client/internal/peer" + +var ( + // EnvKeyNBForceRelay Exported for Android java client + EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay +) + +// EnvList wraps a Go map for export to Java +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair 
+func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} diff --git a/client/android/login.go b/client/android/login.go index d8ac645e2..0df78dbc3 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -33,6 +33,7 @@ type ErrListener interface { // the backend want to show an url for the user type URLOpener interface { Open(string) + OnLoginSuccess() } // Auth can register or login new client @@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error { err = a.withBackOff(a.ctx, func() error { err := internal.Login(a.ctx, a.config, "", jwtToken) + + if err == nil { + go urlOpener.OnLoginSuccess() + } + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return nil } diff --git a/client/cmd/down.go b/client/cmd/down.go index 3ce51c678..17c152d22 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -27,7 +27,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/cmd/login.go b/client/cmd/login.go index 92de6abdb..3ac211805 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, } // update host's static platform and system information - system.UpdateStaticInfo() + system.UpdateStaticInfoAsync() configFilePath, err := activeProf.FilePath() if err != nil { diff --git a/client/cmd/root.go b/client/cmd/root.go index 8aa0d7c89..11e5228f1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -39,6 +39,7 @@ const ( extraIFaceBlackListFlag = 
"extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" enableLazyConnectionFlag = "enable-lazy-connection" + mtuFlag = "mtu" ) var ( @@ -72,6 +73,7 @@ var ( anonymizeFlag bool dnsRouteInterval time.Duration lazyConnEnabled bool + mtu uint16 profilesDisabled bool updateSettingsDisabled bool @@ -229,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string { // DialClientGRPCServer returns client connection to the daemon server. func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*3) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() return grpc.DialContext( diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go index 844eea853..ce95786dd 100644 --- a/client/cmd/root_test.go +++ b/client/cmd/root_test.go @@ -54,6 +54,7 @@ func TestSetFlagsFromEnvVars(t *testing.T) { cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.") cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") + cmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface") t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec") t.Setenv("NB_INTERFACE_NAME", "test-name") diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 50fb35d5e..0545ce6b7 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error { log.Info("starting NetBird service") //nolint // Collect static system and platform information - system.UpdateStaticInfo() + system.UpdateStaticInfoAsync() // in any case, even if configuration does not exists we run daemon to 
serve CLI gRPC API. p.serv = grpc.NewServer() diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 42cca1a9b..729b191c3 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -9,29 +9,26 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "google.golang.org/grpc" + "github.com/netbirdio/management-integrations/integrations" + clientProto "github.com/netbirdio/netbird/client/proto" + client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" + mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" - - "github.com/netbirdio/netbird/util" - - "google.golang.org/grpc" - - "github.com/netbirdio/management-integrations/integrations" - - clientProto "github.com/netbirdio/netbird/client/proto" - client "github.com/netbirdio/netbird/client/server" - mgmt "github.com/netbirdio/netbird/management/server" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto" sig "github.com/netbirdio/netbird/signal/server" + "github.com/netbirdio/netbird/util" ) func startTestingServices(t *testing.T) string { @@ -91,15 +88,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { return nil, nil } - iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) - metrics, err 
:= telemetry.NewDefaultAppMetrics(context.Background()) - require.NoError(t, err) ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) - settingsMockManager := settings.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl) + peersmanager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) + + iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() settingsMockManager.EXPECT(). diff --git a/client/cmd/up.go b/client/cmd/up.go index 7cc342fe0..1b751aa55 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -63,6 +63,7 @@ func init() { upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") + upCmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface") upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+ `E.g. 
--network-monitor=false to disable or --network-monitor=true to enable.`, @@ -230,7 +231,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager client := proto.NewDaemonServiceClient(conn) - status, err := client.Status(ctx, &proto.StatusRequest{}) + status, err := client.Status(ctx, &proto.StatusRequest{ + WaitForReady: func() *bool { b := true; return &b }(), + }) if err != nil { return fmt.Errorf("unable to get daemon status: %v", err) } @@ -358,6 +361,11 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro req.WireguardPort = &p } + if cmd.Flag(mtuFlag).Changed { + m := int64(mtu) + req.Mtu = &m + } + if cmd.Flag(networkMonitorFlag).Changed { req.NetworkMonitor = &networkMonitor } @@ -437,6 +445,13 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.WireguardPort = &p } + if cmd.Flag(mtuFlag).Changed { + if err := iface.ValidateMTU(mtu); err != nil { + return nil, err + } + ic.MTU = &mtu + } + if cmd.Flag(networkMonitorFlag).Changed { ic.NetworkMonitor = &networkMonitor } @@ -534,6 +549,14 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.WireguardPort = &wp } + if cmd.Flag(mtuFlag).Changed { + if err := iface.ValidateMTU(mtu); err != nil { + return nil, err + } + m := int64(mtu) + loginRequest.Mtu = &m + } + if cmd.Flag(networkMonitorFlag).Changed { loginRequest.NetworkMonitor = &networkMonitor } diff --git a/client/embed/embed.go b/client/embed/embed.go index 79f5f0e43..c62efc960 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -137,7 +137,7 @@ func (c *Client) Start(startCtx context.Context) error { // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available - run := make(chan struct{}, 1) + run := make(chan struct{}) clientErr := make(chan error, 1) go func() { if err := client.Run(run); err != nil { diff --git 
a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 7b90000a8..ed8a7403b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -12,7 +12,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 1e44c7a4d..081991235 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // constants needed to manage and create iptable rules diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index e9eeff863..3490c5dad 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,7 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) func isIptablesSupported() bool { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 52979d257..9ff5b8c92 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet 
"github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index f8fed4d80..e918d0524 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -22,7 +22,7 @@ import ( nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/util/grpc/dialer.go b/client/grpc/dialer.go similarity index 91% rename from util/grpc/dialer.go rename to client/grpc/dialer.go index f6d6d2f04..69e3f088c 100644 --- a/util/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -20,8 +20,9 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + nbnet "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" ) func WithCustomDialer() grpc.DialOption { @@ -57,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -71,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { })) } - connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() conn, err := grpc.DialContext( diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go index 89bddf12c..32b07c330 100644 --- a/client/iface/bind/control.go 
+++ b/client/iface/bind/control.go @@ -3,7 +3,7 @@ package bind import ( wireguard "golang.zx2c4.com/wireguard/conn" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go index 1926ff88f..caa92f05d 100644 --- a/client/iface/bind/endpoint.go +++ b/client/iface/bind/endpoint.go @@ -1,5 +1,17 @@ package bind -import wgConn "golang.zx2c4.com/wireguard/conn" +import ( + "net" + + wgConn "golang.zx2c4.com/wireguard/conn" +) type Endpoint = wgConn.StdNetEndpoint + +func EndpointToUDPAddr(e Endpoint) *net.UDPAddr { + return &net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + } +} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 41f4aec6d..ef630b9d0 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,6 +1,7 @@ package bind import ( + "context" "encoding/binary" "fmt" "net" @@ -8,15 +9,16 @@ import ( "runtime" "sync" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type RecvMessage struct { @@ -41,10 +43,10 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. 
type ICEBind struct { *wgConn.StdNetBind - RecvChan chan RecvMessage + recvChan chan RecvMessage transportNet transport.Net - filterFn FilterFn + filterFn udpmux.FilterFn endpoints map[netip.Addr]net.Conn endpointsMu sync.Mutex // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a @@ -54,21 +56,23 @@ type ICEBind struct { closed bool muUDPMux sync.Mutex - udpMux *UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault address wgaddr.Address + mtu uint16 activityRecorder *ActivityRecorder } -func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - RecvChan: make(chan RecvMessage, 1), + recvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, endpoints: make(map[netip.Addr]net.Conn), closedChan: make(chan struct{}), closed: true, + mtu: mtu, address: address, activityRecorder: NewActivityRecorder(), } @@ -80,6 +84,10 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad return ib } +func (s *ICEBind) MTU() uint16 { + return s.mtu +} + func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { s.closed = false s.closedChanMu.Lock() @@ -109,7 +117,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder { } // GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { +func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() if s.udpMux == nil { @@ -148,16 +156,25 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } +func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { + select { + case <-ctx.Done(): + return + 
case b.recvChan <- msg: + } +} + func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = NewUniversalUDPMuxDefault( - UniversalUDPMuxParams{ + s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ UDPConn: nbnet.WrapPacketConn(conn), Net: s.transportNet, FilterFn: s.filterFn, WGAddress: s.address, + MTU: s.mtu, }, ) return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { @@ -263,7 +280,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo select { case <-c.closedChan: return 0, net.ErrClosed - case msg, ok := <-c.RecvChan: + case msg, ok := <-c.recvChan: if !ok { return 0, net.ErrClosed } diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go deleted file mode 100644 index 15e26d02f..000000000 --- a/client/iface/bind/udp_mux_ios.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build ios - -package bind - -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { - // iOS doesn't support nbnet hooks, so this is a no-op -} \ No newline at end of file diff --git a/client/iface/bufsize/bufsize.go b/client/iface/bufsize/bufsize.go new file mode 100644 index 000000000..0d2afb77d --- /dev/null +++ b/client/iface/bufsize/bufsize.go @@ -0,0 +1,9 @@ +package bufsize + +const ( + // WGBufferOverhead represents the additional buffer space needed beyond MTU + // for WireGuard packet encapsulation (WG header + UDP + IP + safety margin) + // Original hardcoded buffers were 1500, default MTU is 1280, so overhead = 220 + // TODO: Calculate this properly based on actual protocol overhead instead of using hardcoded difference + WGBufferOverhead = 220 +) diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 171458e38..f744e0127 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -17,8 +17,8 
@@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/monotime" - nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) { if err != nil { return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) } + + // If sec is 0 (Unix epoch), return zero time instead + // This indicates no handshake has occurred + if sec == 0 { + return time.Time{}, nil + } + return time.Unix(sec, 0), nil } @@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) { } func getFwmark() int { - if nbnet.AdvancedRouting() { + if nbnet.AdvancedRouting() && runtime.GOOS == "linux" { return nbnet.ControlPlaneMark } return 0 diff --git a/client/iface/device.go b/client/iface/device.go index 81f2e0f47..921f0ea98 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -7,16 +7,17 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create() (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address + MTU() uint16 DeviceName() string Close() error FilteredDevice() *device.FilteredDevice diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 4fe6e466b..a731684cc 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" 
"github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -21,7 +22,7 @@ type WGTunDevice struct { address wgaddr.Address port int key string - mtu int + mtu uint16 iceBind *bind.ICEBind tunAdapter TunAdapter disableDNS bool @@ -29,11 +30,11 @@ type WGTunDevice struct { name string device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { +func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ address: address, port: port, @@ -58,7 +59,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string searchDomainsToString = "" } - fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return nil, err @@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string } return t.configurer, nil } -func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -137,6 +138,10 @@ func (t *WGTunDevice) WgAddress() wgaddr.Address { return t.address } +func (t *WGTunDevice) MTU() uint16 { + return t.mtu +} + func (t *WGTunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 81de0e360..390efe088 100644 --- a/client/iface/device/device_darwin.go +++ 
b/client/iface/device/device_darwin.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -21,16 +22,16 @@ type TunDevice struct { address wgaddr.Address port int key string - mtu int + mtu uint16 iceBind *bind.ICEBind device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, @@ -42,7 +43,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu } func (t *TunDevice) Create() (WGConfigurer, error) { - tunDevice, err := tun.CreateTUN(t.name, t.mtu) + tunDevice, err := tun.CreateTUN(t.name, int(t.mtu)) if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } @@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -111,6 +112,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } +func (t *TunDevice) MTU() uint16 { + return t.mtu +} + func (t *TunDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 4613762c3..96e4c8bcf 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" 
"github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -22,21 +23,23 @@ type TunDevice struct { address wgaddr.Address port int key string + mtu uint16 iceBind *bind.ICEBind tunFd int device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunFd int) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, + mtu: mtu, iceBind: iceBind, tunFd: tunFd, } @@ -81,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -125,6 +128,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } +func (t *TunDevice) MTU() uint16 { + return t.mtu +} + func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { // todo implement return nil diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 7136be0bc..cdac43a53 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -12,11 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/sharedsock" - nbnet 
"github.com/netbirdio/netbird/util/net" ) type TunKernelDevice struct { @@ -24,19 +24,19 @@ type TunKernelDevice struct { address wgaddr.Address wgPort int key string - mtu int + mtu uint16 ctx context.Context ctxCancel context.CancelFunc transportNet transport.Net link *wgLink udpMuxConn net.PacketConn - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault - filterFn bind.FilterFn + filterFn udpmux.FilterFn } -func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { +func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { ctx, cancel := context.WithCancel(context.Background()) return &TunKernelDevice{ ctx: ctx, @@ -66,7 +66,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { // TODO: do a MTU discovery log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name) - if err := link.setMTU(t.mtu); err != nil { + if err := link.setMTU(int(t.mtu)); err != nil { return nil, fmt.Errorf("set mtu: %w", err) } @@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { return configurer, nil } -func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -96,23 +96,19 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return nil, err } - rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter()) + rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter(), t.mtu) if err != nil { return nil, err } - var udpConn net.PacketConn = rawSock - if !nbnet.AdvancedRouting() { - udpConn = nbnet.WrapPacketConn(rawSock) - } - - bindParams := bind.UniversalUDPMuxParams{ - UDPConn: udpConn, + bindParams := udpmux.UniversalUDPMuxParams{ + UDPConn: nbnet.WrapPacketConn(rawSock), Net: 
t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, + MTU: t.mtu, } - mux := bind.NewUniversalUDPMuxDefault(bindParams) + mux := udpmux.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) t.udpMuxConn = rawSock t.udpMux = mux @@ -158,6 +154,10 @@ func (t *TunKernelDevice) WgAddress() wgaddr.Address { return t.address } +func (t *TunKernelDevice) MTU() uint16 { + return t.mtu +} + func (t *TunKernelDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index fc3cb0215..a6ef47027 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,6 +1,3 @@ -//go:build !android -// +build !android - package device import ( @@ -13,8 +10,9 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type TunNetstackDevice struct { @@ -22,20 +20,20 @@ type TunNetstackDevice struct { address wgaddr.Address port int key string - mtu int + mtu uint16 listenAddress string iceBind *bind.ICEBind device *device.Device filteredDevice *FilteredDevice nsTun *nbnetstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer net *netstack.Net } -func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { +func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -47,7 +45,7 @@ func NewNetstackDevice(name string, 
address wgaddr.Address, wgPort int, key stri } } -func (t *TunNetstackDevice) Create() (WGConfigurer, error) { +func (t *TunNetstackDevice) create() (WGConfigurer, error) { log.Info("create nbnetstack tun interface") // TODO: get from service listener runtime IP @@ -57,7 +55,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) { } log.Debugf("netstack using address: %s", t.address.IP) - t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) + t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu)) log.Debugf("netstack using dns address: %s", dnsAddr) tunIface, net, err := t.nsTun.Create() if err != nil { @@ -83,7 +81,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -125,6 +123,10 @@ func (t *TunNetstackDevice) WgAddress() wgaddr.Address { return t.address } +func (t *TunNetstackDevice) MTU() uint16 { + return t.mtu +} + func (t *TunNetstackDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_netstack_android.go b/client/iface/device/device_netstack_android.go new file mode 100644 index 000000000..45ae8ba7d --- /dev/null +++ b/client/iface/device/device_netstack_android.go @@ -0,0 +1,7 @@ +//go:build android + +package device + +func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { + return t.create() +} diff --git a/client/iface/device/device_netstack_generic.go b/client/iface/device/device_netstack_generic.go new file mode 100644 index 000000000..4b3974f26 --- /dev/null +++ b/client/iface/device/device_netstack_generic.go @@ -0,0 +1,7 @@ +//go:build !android + +package device + +func (t *TunNetstackDevice) Create() (WGConfigurer, error) { + 
return t.create() +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index e781f6004..4cdd70a32 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -20,16 +21,16 @@ type USPDevice struct { address wgaddr.Address port int key string - mtu int + mtu uint16 iceBind *bind.ICEBind device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } -func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { +func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice { log.Infof("using userspace bind mode") return &USPDevice{ @@ -44,9 +45,9 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu func (t *USPDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") - tunIface, err := tun.CreateTUN(t.name, t.mtu) + tunIface, err := tun.CreateTUN(t.name, int(t.mtu)) if err != nil { - log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) + log.Debugf("failed to create tun interface (%s, %d): %s", t.name, int(t.mtu), err) return nil, fmt.Errorf("error creating tun device: %s", err) } t.filteredDevice = newDeviceFilter(tunIface) @@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -118,6 +119,10 @@ func (t *USPDevice) 
WgAddress() wgaddr.Address { return t.address } +func (t *USPDevice) MTU() uint16 { + return t.mtu +} + func (t *USPDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 0316c4b8d..f1023bc0a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -23,17 +24,17 @@ type TunDevice struct { address wgaddr.Address port int key string - mtu int + mtu uint16 iceBind *bind.ICEBind device *device.Device nativeTunDevice *tun.NativeTun filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, @@ -59,7 +60,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return nil, err } log.Info("create tun interface") - tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu) + tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, int(t.mtu)) if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } @@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -144,6 +145,10 @@ func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } +func (t *TunDevice) MTU() uint16 { + 
return t.mtu +} + func (t *TunDevice) DeviceName() string { return t.name } diff --git a/client/iface/device_android.go b/client/iface/device_android.go index a1e246fc5..4649b8b97 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -5,16 +5,17 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address + MTU() uint16 DeviceName() string Close() error FilteredDevice() *device.FilteredDevice diff --git a/client/iface/iface.go b/client/iface/iface.go index 0e41f8e64..609572561 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -16,9 +16,9 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -26,6 +26,8 @@ import ( const ( DefaultMTU = 1280 + MinMTU = 576 + MaxMTU = 8192 DefaultWgPort = 51820 WgInterfaceDefault = configurer.WgInterfaceDefault ) @@ -35,6 +37,17 @@ var ( ErrIfaceNotFound = fmt.Errorf("wireguard interface not found") ) +// ValidateMTU validates that MTU is within acceptable range +func ValidateMTU(mtu uint16) error { + if mtu < MinMTU { + return fmt.Errorf("MTU %d below minimum (%d bytes)", mtu, MinMTU) + } + if mtu > MaxMTU { + return 
fmt.Errorf("MTU %d exceeds maximum supported size (%d bytes)", mtu, MaxMTU) + } + return nil +} + type wgProxyFactory interface { GetProxy() wgproxy.Proxy Free() error @@ -45,10 +58,10 @@ type WGIFaceOpts struct { Address string WGPort int WGPrivKey string - MTU int + MTU uint16 MobileArgs *device.MobileIFaceArguments TransportNet transport.Net - FilterFn bind.FilterFn + FilterFn udpmux.FilterFn DisableDNS bool } @@ -82,6 +95,10 @@ func (w *WGIface) Address() wgaddr.Address { return w.tun.WgAddress() } +func (w *WGIface) MTU() uint16 { + return w.tun.MTU() +} + // ToInterface returns the net.Interface for the Wireguard interface func (r *WGIface) ToInterface() *net.Interface { name := r.tun.DeviceName() @@ -97,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface { // Up configures a Wireguard interface // The interface must exist before calling this method (e.g. call interface.Create() before) -func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index c8babea32..26952f48d 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -3,6 +3,7 @@ package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) @@ -14,7 +15,16 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + + if netstack.IsEnabled() { + wgIFace := &WGIface{ + userspaceBind: true, + tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, 
iceBind, netstack.ListenAddr()), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil + } wgIFace := &WGIface{ userspaceBind: true, diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go index 93fd7fd5c..7dd74d571 100644 --- a/client/iface/iface_new_darwin.go +++ b/client/iface/iface_new_darwin.go @@ -17,7 +17,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) var tun WGTunDevice if netstack.IsEnabled() { diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go new file mode 100644 index 000000000..86ed14ce1 --- /dev/null +++ b/client/iface/iface_new_freebsd.go @@ -0,0 +1,41 @@ +//go:build freebsd + +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, 
iceBind) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + return nil, fmt.Errorf("couldn't check or load tun module") +} diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 317ee0f46..06ccf0be1 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -16,10 +16,10 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace := &WGIface{ - tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), + tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, wgProxyFactory: wgproxy.NewUSPFactory(iceBind), } diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_linux.go similarity index 90% rename from client/iface/iface_new_unix.go rename to client/iface/iface_new_linux.go index 23ee7236f..77fd30fae 100644 --- a/client/iface/iface_new_unix.go +++ b/client/iface/iface_new_linux.go @@ -1,4 +1,4 @@ -//go:build (linux && !android) || freebsd +//go:build linux && !android package iface @@ -22,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{} if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) @@ -31,11 +31,11 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { if device.WireGuardModuleIsLoaded() { wgIFace.tun = 
device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) - wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort) + wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU) return wgIFace, nil } if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go index 413062940..349c5b33b 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new_windows.go @@ -14,7 +14,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { if err != nil { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) var tun WGTunDevice if netstack.IsEnabled() { diff --git a/client/iface/bind/udp_muxed_conn.go b/client/iface/udpmux/conn.go similarity index 95% rename from client/iface/bind/udp_muxed_conn.go rename to client/iface/udpmux/conn.go index 7cacf1c31..3aa40caeb 100644 --- a/client/iface/bind/udp_muxed_conn.go +++ b/client/iface/udpmux/conn.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements @@ -16,11 +16,12 @@ import ( ) type udpMuxedConnParams struct { - Mux *UDPMuxDefault - AddrPool *sync.Pool - Key string - LocalAddr net.Addr - Logger logging.LeveledLogger + Mux *SingleSocketUDPMux + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger + CandidateID string } // udpMuxedConn represents a logical packet conn for a single remote as 
identified by ufrag @@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error { return err } +func (c *udpMuxedConn) GetCandidateID() string { + return c.params.CandidateID +} + func (c *udpMuxedConn) isClosed() bool { select { case <-c.closedChan: diff --git a/client/iface/udpmux/doc.go b/client/iface/udpmux/doc.go new file mode 100644 index 000000000..27e5e43bc --- /dev/null +++ b/client/iface/udpmux/doc.go @@ -0,0 +1,64 @@ +// Package udpmux provides a custom implementation of a UDP multiplexer +// that allows multiple logical ICE connections to share a single underlying +// UDP socket. This is based on Pion's ICE library, with modifications for +// NetBird's requirements. +// +// # Background +// +// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity +// Establishment) is responsible for discovering candidate network paths +// and maintaining connectivity between peers. Each ICE connection +// normally requires a dedicated UDP socket. However, using one socket +// per candidate can be inefficient and difficult to manage. +// +// This package introduces SingleSocketUDPMux, which allows multiple ICE +// candidate connections (muxed connections) to share a single UDP socket. +// It handles demultiplexing of packets based on ICE ufrag values, STUN +// attributes, and candidate IDs. +// +// # Usage +// +// The typical flow is: +// +// 1. Create a UDP socket (net.PacketConn). +// 2. Construct Params with the socket and optional logger/net stack. +// 3. Call NewSingleSocketUDPMux(params). +// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID) +// to obtain a logical PacketConn. +// 5. Use the returned PacketConn just like a normal UDP connection. 
+// +// # STUN Message Routing Logic +// +// When a STUN packet arrives, the mux decides which connection should +// receive it using this routing logic: +// +// Primary Routing: Candidate Pair ID +// - Extract the candidate pair ID from the STUN message using +// ice.CandidatePairIDFromSTUN(msg) +// - The target candidate is the locally generated candidate that +// corresponds to the connection that should handle this STUN message +// - If found, use the target candidate ID to lookup the specific +// connection in candidateConnMap +// - Route the message directly to that connection +// +// Fallback Routing: Broadcasting +// When candidate pair ID is not available or lookup fails: +// - Collect connections from addressMap based on source address +// - Find connection using username attribute (ufrag) from STUN message +// - Remove duplicate connections from the list +// - Send the STUN message to all collected connections +// +// # Peer Reflexive Candidate Discovery +// +// When a remote peer sends a STUN message from an unknown source address +// (from a candidate that has not been exchanged via signal), the ICE +// library will: +// - Generate a new peer reflexive candidate for this source address +// - Extract or assign a candidate ID based on the STUN message attributes +// - Create a mapping between the new peer reflexive candidate ID and +// the appropriate local connection +// +// This discovery mechanism ensures that STUN messages from newly discovered +// peer reflexive candidates can be properly routed to the correct local +// connection without requiring fallback broadcasting. 
+package udpmux diff --git a/client/iface/bind/udp_mux.go b/client/iface/udpmux/mux.go similarity index 65% rename from client/iface/bind/udp_mux.go rename to client/iface/udpmux/mux.go index 29e5d7937..319724926 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/udpmux/mux.go @@ -1,4 +1,4 @@ -package bind +package udpmux import ( "fmt" @@ -8,9 +8,9 @@ import ( "strings" "sync" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" "github.com/pion/logging" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" @@ -22,9 +22,9 @@ import ( const receiveMTU = 8192 -// UDPMuxDefault is an implementation of the interface -type UDPMuxDefault struct { - params UDPMuxParams +// SingleSocketUDPMux is an implementation of the interface +type SingleSocketUDPMux struct { + params Params closedChan chan struct{} closeOnce sync.Once @@ -32,6 +32,9 @@ type UDPMuxDefault struct { // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn + // candidateConnMap maps local candidate IDs to their corresponding connection. + candidateConnMap map[string]*udpMuxedConn + addressMapMu sync.RWMutex addressMap map[string][]*udpMuxedConn @@ -46,8 +49,8 @@ type UDPMuxDefault struct { const maxAddrSize = 512 -// UDPMuxParams are parameters for UDPMux. -type UDPMuxParams struct { +// Params are parameters for UDPMux. 
+type Params struct { Logger logging.LeveledLogger UDPConn net.PacketConn @@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool { return true } -// NewUDPMuxDefault creates an implementation of UDPMux -func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { +// NewSingleSocketUDPMux creates an implementation of UDPMux +func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { if params.Logger == nil { params.Logger = getLogger() } - mux := &UDPMuxDefault{ - addressMap: map[string][]*udpMuxedConn{}, - params: params, - connsIPv4: make(map[string]*udpMuxedConn), - connsIPv6: make(map[string]*udpMuxedConn), - closedChan: make(chan struct{}, 1), + mux := &SingleSocketUDPMux{ + addressMap: map[string][]*udpMuxedConn{}, + params: params, + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), + candidateConnMap: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address @@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return mux } -func (m *UDPMuxDefault) updateLocalAddresses() { +func (m *SingleSocketUDPMux) updateLocalAddresses() { var localAddrsForUnspecified []net.Addr if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) } else if ok && addr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection - // with UDPMuxDefault, so print a warn log and create a local address list for mux. - m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + // with SingleSocketUDPMux, so print a warn log and create a local address list for mux. 
+ m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { @@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() { m.mu.Unlock() } -// LocalAddr returns the listening address of this UDPMuxDefault -func (m *UDPMuxDefault) LocalAddr() net.Addr { +// LocalAddr returns the listening address of this SingleSocketUDPMux +func (m *SingleSocketUDPMux) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetListenAddresses returns the list of addresses that this mux is listening on -func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { +func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { m.updateLocalAddresses() m.mu.Lock() @@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { // GetConn returns a PacketConn given the connection's ufrag and network address // creates the connection if an existing one can't be found -func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { +func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) { // don't check addr for mux using unspecified address m.mu.Lock() lenLocalAddrs := len(m.localAddrsForUnspecified) @@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er return conn, nil } - c := m.createMuxedConn(ufrag) + c := m.createMuxedConn(ufrag, candidateID) go func() { <-c.CloseChannel() m.RemoveConnByUfrag(ufrag) }() + m.candidateConnMap[candidateID] = c + if isIPv6 { m.connsIPv6[ufrag] = c } else { @@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er } // RemoveConnByUfrag stops and removes the muxed packet connection -func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { +func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock 
section small to avoid deadlock with conn lock @@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } m.mu.Unlock() @@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { } // IsClosed returns true if the mux had been closed -func (m *UDPMuxDefault) IsClosed() bool { +func (m *SingleSocketUDPMux) IsClosed() bool { select { case <-m.closedChan: return true @@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool { } // Close the mux, no further connections could be created -func (m *UDPMuxDefault) Close() error { +func (m *SingleSocketUDPMux) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() @@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error { return err } -func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { +func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } -func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { +func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } @@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } -func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { +func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ - Mux: m, - Key: key, - AddrPool: m.pool, - LocalAddr: m.LocalAddr(), - Logger: m.params.Logger, + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: 
m.LocalAddr(), + Logger: m.params.Logger, + CandidateID: candidateID, }) return c } // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library -func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { - +func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { remoteAddr, ok := addr.(*net.UDPAddr) if !ok { return fmt.Errorf("underlying PacketConn did not return a UDPAddr") } - // If we have already seen this address dispatch to the appropriate destination - // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one - // muxed connection - one for the SRFLX candidate and the other one for the HOST one. - // We will then forward STUN packets to each of these connections. - m.addressMapMu.RLock() + // Try to route to specific candidate connection first + if conn := m.findCandidateConnection(msg); conn != nil { + return conn.writePacket(msg.Raw, remoteAddr) + } + + // Fallback: route to all possible connections + return m.forwardToAllConnections(msg, addr, remoteAddr) +} + +// findCandidateConnection attempts to find the specific connection for a STUN message +func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn { + candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg) + if err != nil { + return nil + } else if !ok { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()] + if !exists { + return nil + } + return conn +} + +// forwardToAllConnections forwards STUN message to all relevant connections +func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error { var destinationConnList []*udpMuxedConn + + // Add connections from address map + m.addressMapMu.RLock() if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = 
append(destinationConnList, storedConns...) } m.addressMapMu.RUnlock() - var isIPv6 bool - if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { - isIPv6 = true + if conn, ok := m.findConnectionByUsername(msg, addr); ok { + // If we have already seen this address dispatch to the appropriate destination + // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one + // muxed connection - one for the SRFLX candidate and the other one for the HOST one. + // We will then forward STUN packets to each of these connections. + if !m.connectionExists(conn, destinationConnList) { + destinationConnList = append(destinationConnList, conn) + } } - // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. - // However, we can take a username attribute from the STUN message which contains ufrag. - // We can use ufrag to identify the destination conn to route packet to. - attr, stunAttrErr := msg.Get(stun.AttrUsername) - if stunAttrErr == nil { - ufrag := strings.Split(string(attr), ":")[0] - - m.mu.Lock() - destinationConn := m.connsIPv4[ufrag] - if isIPv6 { - destinationConn = m.connsIPv6[ufrag] - } - - if destinationConn != nil { - exists := false - for _, conn := range destinationConnList { - if conn.params.Key == destinationConn.params.Key { - exists = true - break - } - } - if !exists { - destinationConnList = append(destinationConnList, destinationConn) - } - } - m.mu.Unlock() - } - - // Forward STUN packets to each destination connections even thought the STUN packet might not belong there. - // It will be discarded by the further ICE candidate logic if so. 
+ // Forward to all found connections for _, conn := range destinationConnList { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { log.Errorf("could not write packet: %v", err) } } - return nil } -func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { +// findConnectionByUsername finds connection using username attribute from STUN message +func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) { + attr, err := msg.Get(stun.AttrUsername) + if err != nil { + return nil, false + } + + ufrag := strings.Split(string(attr), ":")[0] + isIPv6 := isIPv6Address(addr) + + m.mu.Lock() + defer m.mu.Unlock() + + return m.getConn(ufrag, isIPv6) +} + +// connectionExists checks if a connection already exists in the list +func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool { + for _, conn := range conns { + if conn.params.Key == target.params.Key { + return true + } + } + return false +} + +func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { @@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o return } +func isIPv6Address(addr net.Addr) bool { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + return udpAddr.IP.To4() == nil + } + return false +} + type bufferHolder struct { buf []byte } diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/udpmux/mux_generic.go similarity index 76% rename from client/iface/bind/udp_mux_generic.go rename to client/iface/udpmux/mux_generic.go index 63f786d2b..29fc2d834 100644 --- a/client/iface/bind/udp_mux_generic.go +++ b/client/iface/udpmux/mux_generic.go @@ -1,12 +1,12 @@ //go:build !ios -package bind +package udpmux import ( - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) -func (m *UDPMuxDefault) 
notifyAddressRemoval(addr string) { +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { conn.RemoveAddress(addr) diff --git a/client/iface/udpmux/mux_ios.go b/client/iface/udpmux/mux_ios.go new file mode 100644 index 000000000..4cf211d8f --- /dev/null +++ b/client/iface/udpmux/mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package udpmux + +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/udpmux/universal.go similarity index 95% rename from client/iface/bind/udp_mux_universal.go rename to client/iface/udpmux/universal.go index b755a7827..43bfedaaa 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/udpmux/universal.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. @@ -15,9 +15,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/pion/logging" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" + "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -28,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. 
type UniversalUDPMuxDefault struct { - *UDPMuxDefault + *SingleSocketUDPMux params UniversalUDPMuxParams // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents @@ -44,6 +45,7 @@ type UniversalUDPMuxParams struct { Net transport.Net FilterFn FilterFn WGAddress wgaddr.Address + MTU uint16 } // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux @@ -70,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - udpMuxParams := UDPMuxParams{ + udpMuxParams := Params{ Logger: params.Logger, UDPConn: m.params.UDPConn, Net: m.params.Net, } - m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams) return m } @@ -84,7 +86,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef // just ignore other packets printing an warning message. // It is a blocking method, consider running in a go routine. func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { - buf := make([]byte, 1500) + buf := make([]byte, m.params.MTU+bufsize.WGBufferOverhead) for { select { case <-ctx.Done(): @@ -209,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. 
-func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { - return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) { + return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID) } // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. @@ -231,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A } return nil } - return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) + return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr) } // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index f68e84810..dbc694e91 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -12,31 +12,41 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) +type IceBind interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) + Recv(ctx context.Context, msg bind.RecvMessage) + MTU() uint16 +} + type ProxyBind struct { - Bind *bind.ICEBind + bind IceBind - fakeNetIP *netip.AddrPort - wgBindEndpoint *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address + wgRelayedEndpoint *bind.Endpoint + wgCurrentUsed *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu 
sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyBind(bind *bind.ICEBind) *ProxyBind { +func NewProxyBind(bind IceBind) *ProxyBind { p := &ProxyBind{ - Bind: bind, + bind: bind, closeListener: listener.NewCloseListener(), + pausedCond: sync.NewCond(&sync.Mutex{}), } return p @@ -45,25 +55,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind { // AddTurnConn adds a new connection to the bind. // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. +// +// Parameters: +// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages +// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address +// - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } - - p.fakeNetIP = fakeNetIP - p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} + p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return nil } + func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return &net.UDPAddr{ - IP: p.fakeNetIP.Addr().AsSlice(), - Port: int(p.fakeNetIP.Port()), - Zone: p.fakeNetIP.Addr().Zone(), - } + return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) } func (p *ProxyBind) SetDisconnectListener(disconnected func()) { @@ -75,17 +85,21 @@ func (p *ProxyBind) Work() { return } - p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) + p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgCurrentUsed = p.wgRelayedEndpoint // Start the proxy only once if !p.isStarted { p.isStarted = true go 
p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyBind) Pause() { @@ -93,9 +107,25 @@ func (p *ProxyBind) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgCurrentUsed = addrToEndpoint(endpoint) + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} } func (p *ProxyBind) CloseConn() error { @@ -106,6 +136,10 @@ func (p *ProxyBind) CloseConn() error { } func (p *ProxyBind) close() error { + if p.remoteConn == nil { + return nil + } + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -119,7 +153,12 @@ func (p *ProxyBind) close() error { p.cancel() - p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -135,7 +174,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, 1500) + buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -146,18 +185,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } msg := bind.RecvMessage{ - Endpoint: p.wgBindEndpoint, + Endpoint: p.wgCurrentUsed, Buffer: buf[:n], } - p.Bind.RecvChan <- msg - p.pausedMu.Unlock() + p.bind.Recv(ctx, msg) + p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/ebpf/proxy.go 
b/client/iface/wgproxy/ebpf/proxy.go index e21fc35d4..858143091 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -6,9 +6,7 @@ import ( "context" "fmt" "net" - "os" "sync" - "syscall" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -17,18 +15,25 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface/bufsize" + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( loopbackAddr = "127.0.0.1" ) +var ( + localHostNetIP = net.ParseIP("127.0.0.1") +) + // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int + mtu uint16 ebpfManager ebpfMgr.Manager turnConnStore map[uint16]net.Conn @@ -43,10 +48,11 @@ type WGEBPFProxy struct { } // NewWGEBPFProxy create new WGEBPFProxy instance -func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { +func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy { log.Debugf("instantiate ebpf proxy") wgProxy := &WGEBPFProxy{ localWGListenPort: wgPort, + mtu: mtu, ebpfManager: ebpf.GetEbpfManagerInstance(), turnConnStore: make(map[uint16]net.Conn), } @@ -61,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = p.prepareSenderRawSocket() + p.rawConn, err = rawsocket.PrepareSenderRawSocket() if err != nil { return err } @@ -138,7 +144,7 @@ func (p *WGEBPFProxy) Free() error { // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. 
func (p *WGEBPFProxy) proxyToRemote() { - buf := make([]byte, 1500) + buf := make([]byte, p.mtu+bufsize.WGBufferOverhead) for p.ctx.Err() == nil { if err := p.readAndForwardPacket(buf); err != nil { if p.ctx.Err() != nil { @@ -211,57 +217,17 @@ generatePort: return p.lastUsedPort, nil } -func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { - // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) - if err != nil { - return nil, fmt.Errorf("creating raw socket failed: %w", err) - } - - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) - } - - // Bind the socket to the "lo" interface. - err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") - if err != nil { - return nil, fmt.Errorf("binding to lo interface failed: %w", err) - } - - // Set the fwmark on the socket. - err = nbnet.SetSocketOpt(fd) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - - // Convert the file descriptor to a PacketConn. 
- file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) - if file == nil { - return nil, fmt.Errorf("converting fd to file failed") - } - packetConn, err := net.FilePacketConn(file) - if err != nil { - return nil, fmt.Errorf("converting file to packet conn failed: %w", err) - } - - return packetConn, nil -} - -func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { - localhost := net.ParseIP("127.0.0.1") - +func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { payload := gopacket.Payload(data) ipH := &layers.IPv4{ - DstIP: localhost, - SrcIP: localhost, + DstIP: localHostNetIP, + SrcIP: endpointAddr.IP, Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, } udpH := &layers.UDP{ - SrcPort: layers.UDPPort(port), + SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } @@ -276,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { if err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { + if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/proxy_test.go b/client/iface/wgproxy/ebpf/proxy_test.go index b15bc686c..3ec4f0eba 100644 --- a/client/iface/wgproxy/ebpf/proxy_test.go +++ b/client/iface/wgproxy/ebpf/proxy_test.go @@ -7,7 +7,7 @@ import ( ) func TestWGEBPFProxy_connStore(t *testing.T) { - wgProxy := NewWGEBPFProxy(1) + wgProxy := NewWGEBPFProxy(1, 1280) p, _ := wgProxy.storeTurnConn(nil) if p != 1 { @@ -27,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) { } func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { - wgProxy := NewWGEBPFProxy(1) + wgProxy := NewWGEBPFProxy(1, 1280) _, _ = wgProxy.storeTurnConn(nil) wgProxy.lastUsedPort = 65535 @@ -43,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { 
} func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { - wgProxy := NewWGEBPFProxy(1) + wgProxy := NewWGEBPFProxy(1, 1280) for i := 0; i < 65535; i++ { _, _ = wgProxy.storeTurnConn(nil) diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index b25dc4198..ff44d30c0 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -12,46 +12,48 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { - WgeBPFProxy *WGEBPFProxy + wgeBPFProxy *WGEBPFProxy remoteConn net.Conn ctx context.Context cancel context.CancelFunc - wgEndpointAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + wgEndpointCurrentUsedAddr *net.UDPAddr - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { +func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { return &ProxyWrapper{ - WgeBPFProxy: WgeBPFProxy, + wgeBPFProxy: proxy, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } } - func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { - addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) + addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) - p.wgEndpointAddr = addr + p.wgRelayedEndpointAddr = addr return err } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { - return p.wgEndpointAddr + return p.wgRelayedEndpointAddr } func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { @@ -63,14 +65,18 @@ func (p *ProxyWrapper) 
Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyWrapper) Pause() { @@ -79,45 +85,59 @@ func (p *ProxyWrapper) Pause() { } log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgEndpointCurrentUsedAddr = endpoint + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map -func (e *ProxyWrapper) CloseConn() error { - if e.cancel == nil { +func (p *ProxyWrapper) CloseConn() error { + if p.cancel == nil { return fmt.Errorf("proxy not started") } - e.cancel() + p.cancel() - e.closeListener.SetCloseListener(nil) + p.closeListener.SetCloseListener(nil) - if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - return fmt.Errorf("close remote conn: %w", err) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to close remote conn: %w", err) } return nil } func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { - defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port)) - buf := make([]byte, 1500) + buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead) for { n, err := p.readFromRemote(ctx, buf) if err != nil { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } 
- err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) - p.pausedMu.Unlock() + err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { @@ -136,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } p.closeListener.Notify() if !errors.Is(err, io.EOF) { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err) } return 0, err } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index e62cd97be..ad2807546 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -11,16 +11,18 @@ import ( type KernelFactory struct { wgPort int + mtu uint16 ebpfProxy *ebpf.WGEBPFProxy } -func NewKernelFactory(wgPort int) *KernelFactory { +func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { f := &KernelFactory{ wgPort: wgPort, + mtu: mtu, } - ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu) if err := ebpfProxy.Listen(); err != nil { log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) @@ -33,11 +35,10 @@ func NewKernelFactory(wgPort int) *KernelFactory { func (w *KernelFactory) GetProxy() Proxy { if w.ebpfProxy == nil { - return udpProxy.NewWGUDPProxy(w.wgPort) + return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu) } return ebpf.NewProxyWrapper(w.ebpfProxy) - } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go deleted file mode 100644 index 736944229..000000000 --- a/client/iface/wgproxy/factory_kernel_freebsd.go +++ /dev/null @@ -1,29 +0,0 @@ -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - udpProxy 
"github.com/netbirdio/netbird/client/iface/wgproxy/udp" -) - -// KernelFactory todo: check eBPF support on FreeBSD -type KernelFactory struct { - wgPort int -} - -func NewKernelFactory(wgPort int) *KernelFactory { - log.Infof("WireGuard Proxy Factory will produce UDP proxy") - f := &KernelFactory{ - wgPort: wgPort, - } - - return f -} - -func (w *KernelFactory) GetProxy() Proxy { - return udpProxy.NewWGUDPProxy(w.wgPort) -} - -func (w *KernelFactory) Free() error { - return nil -} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index c2879877e..3c8dfd30e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -11,6 +11,11 @@ type Proxy interface { EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + + //RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused + //and rewrite the src address to the endpoint address. + //With this logic can avoid the package loss from relayed connections. 
+ RedirectAs(endpoint *net.UDPAddr) CloseConn() error SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 298c98cc0..9526e91d2 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -3,54 +3,82 @@ package wgproxy import ( - "context" - "os" - "testing" + "fmt" + "net" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) -func TestProxyCloseByRemoteConnEBPF(t *testing.T) { - if os.Getenv("GITHUB_ACTIONS") != "true" { - t.Skip("Skipping test as it requires root privileges") - } - ctx := context.Background() +func seedProxies() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) - ebpfProxy := ebpf.NewWGEBPFProxy(51831) + ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - - tests := []struct { - name string - proxy Proxy - }{ - { - name: "ebpf proxy", - proxy: &ebpf.ProxyWrapper{ - WgeBPFProxy: ebpfProxy, - }, - }, + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, } + pl = append(pl, pEbpf) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) - if err != nil { - t.Errorf("error: %v", err) - } - - _ = relayedConn.Close() - if err := tt.proxy.CloseConn(); err != nil { - t.Errorf("error: %v", err) - } - }) + pUDP := 
proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, } + pl = append(pl, pUDP) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + + ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) + if err := ebpfProxy.Listen(); err != nil { + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) + } + + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, + } + pl = append(pl, pEbpf) + + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, + } + pl = append(pl, pUDP) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + + return pl, nil } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go new file mode 100644 index 000000000..4d244f18a --- /dev/null +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -0,0 +1,39 @@ +//go:build !linux + +package wgproxy + +import ( + "net" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +func seedProxies() ([]proxyInstance, error) { + // todo extend with Bind proxy + pl := make([]proxyInstance, 0) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + wgAddress, err := 
wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + return pl, nil +} diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 6882f9ea2..1aeab66b7 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -1,5 +1,3 @@ -//go:build linux - package wgproxy import ( @@ -7,12 +5,9 @@ import ( "io" "net" "os" - "runtime" "testing" "time" - "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -22,6 +17,14 @@ func TestMain(m *testing.M) { os.Exit(code) } +type proxyInstance struct { + name string + proxy Proxy + wgPort int + endpointAddr *net.UDPAddr + closeFn func() error +} + type mocConn struct { closeChan chan struct{} closed bool @@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error { func TestProxyCloseByRemoteConn(t *testing.T) { ctx := context.Background() - tests := []struct { - name string - proxy Proxy - }{ - { - name: "userspace proxy", - proxy: udpProxy.NewWGUDPProxy(51830), - }, + tests, err := seedProxyForProxyCloseByRemoteConn() + if err != nil { + t.Fatalf("error: %v", err) } - if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { - ebpfProxy := ebpf.NewWGEBPFProxy(51831) - if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) - } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) - - tests = append(tests, struct { - name string - 
proxy Proxy - }{ - name: "ebpf proxy", - proxy: proxyWrapper, - }) - } + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + defer func() { + _ = relayedConn.Close() + }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892") relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + err := tt.proxy.AddTurnConn(ctx, addr, relayedConn) if err != nil { t.Errorf("error: %v", err) } @@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }) } } + +// TestProxyRedirect todo extend the proxies with Bind proxy +func TestProxyRedirect(t *testing.T) { + tests, err := seedProxies() + if err != nil { + t.Fatalf("error: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr) + if err := tt.closeFn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} + +func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) { + t.Helper() + + msgHelloFromRelay := []byte("hello from relay") + msgRedirected := [][]byte{ + []byte("hello 1. to p2p"), + []byte("hello 2. to p2p"), + []byte("hello 3. 
to p2p"), + } + + dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: wgPort}) + if err != nil { + t.Fatalf("failed to listen on udp port: %s", err) + } + + relayedServer, _ := net.ListenUDP("udp", + &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + }, + ) + + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + + defer func() { + _ = dummyWgListener.Close() + _ = relayedConn.Close() + _ = relayedServer.Close() + }() + + if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil { + t.Errorf("error: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }() + + proxy.Work() + + if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil { + t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err) + } + + n, err := dummyWgListener.Read(make([]byte, 1024)) + if err != nil { + t.Errorf("error: %v", err) + } + + if n != len(msgHelloFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n) + } + + p2pEndpointAddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 56), + Port: 1234, + } + proxy.RedirectAs(p2pEndpointAddr) + + for _, msg := range msgRedirected { + if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil { + t.Errorf("error: %v", err) + } + } + + for i := 0; i < len(msgRedirected); i++ { + buf := make([]byte, 1024) + n, rAddr, err := dummyWgListener.ReadFrom(buf) + if err != nil { + t.Errorf("error: %v", err) + } + + if rAddr.String() != p2pEndpointAddr.String() { + t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String()) + } + if string(buf[:n]) != string(msgRedirected[i]) { + t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n])) + } + } +} diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go new file mode 100644 index 000000000..a11ac46d5 --- /dev/null 
+++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -0,0 +1,50 @@ +//go:build linux && !android + +package rawsocket + +import ( + "fmt" + "net" + "os" + "syscall" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func PrepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + if err != nil { + return nil, fmt.Errorf("creating raw socket failed: %w", err) + } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) + if err != nil { + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + + // Bind the socket to the "lo" interface. + err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") + if err != nil { + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = nbnet.SetSocketOpt(fd) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. 
+ file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) + } + + return packetConn, nil +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 139ccd4ed..4ef2f19c4 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,3 +1,5 @@ +//go:build linux && !android + package udp import ( @@ -12,32 +14,38 @@ import ( log "github.com/sirupsen/logrus" cerrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // WGUDPProxy proxies type WGUDPProxy struct { localWGListenPort int + mtu uint16 - remoteConn net.Conn - localConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + remoteConn net.Conn + localConn net.Conn + srcFakerConn *SrcFaker + sendPkg func(data []byte) (int, error) + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } // NewWGUDPProxy instantiate a UDP based WireGuard proxy. 
This is not a thread safe implementation -func NewWGUDPProxy(wgPort int) *WGUDPProxy { +func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { log.Debugf("Initializing new user space proxy with port %d", wgPort) p := &WGUDPProxy{ localWGListenPort: wgPort, + mtu: mtu, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } return p @@ -58,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem p.ctx, p.cancel = context.WithCancel(ctx) p.localConn = localConn + p.sendPkg = p.localConn.Write p.remoteConn = remoteConn return err @@ -81,15 +90,24 @@ func (p *WGUDPProxy) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + p.sendPkg = p.localConn.Write + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } if !p.isStarted { p.isStarted = true go p.proxyToRemote(p.ctx) go p.proxyToLocal(p.ctx) } + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // Pause pauses the proxy from receiving data from the remote peer @@ -98,9 +116,35 @@ func (p *WGUDPProxy) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +// RedirectAs start to use the fake sourced raw socket as package sender +func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + defer func() { + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + }() + + p.paused = false + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } + srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create src faker conn: %s", err) + // fallback to continue without redirecting + p.paused = true + return + } + p.srcFakerConn = srcFakerConn + 
p.sendPkg = p.srcFakerConn.SendPkg } // CloseConn close the localConn @@ -112,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error { } func (p *WGUDPProxy) close() error { + var result *multierror.Error + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -125,7 +171,11 @@ func (p *WGUDPProxy) close() error { p.cancel() - var result *multierror.Error + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } @@ -133,6 +183,13 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err)) + } + } + return cerrors.FormatErrorOrNil(result) } @@ -144,7 +201,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { } }() - buf := make([]byte, 1500) + buf := make([]byte, p.mtu+bufsize.WGBufferOverhead) for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { @@ -179,7 +236,7 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { } }() - buf := make([]byte, 1500) + buf := make([]byte, p.mtu+bufsize.WGBufferOverhead) for { n, err := p.remoteConnRead(ctx, buf) if err != nil { @@ -191,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - - _, err = p.localConn.Write(buf[:n]) - p.pausedMu.Unlock() + _, err = p.sendPkg(buf[:n]) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go new file mode 100644 index 000000000..fdc911463 --- /dev/null +++ 
b/client/iface/wgproxy/udp/rawsocket.go @@ -0,0 +1,101 @@ +//go:build linux && !android + +package udp + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" +) + +var ( + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + localHostNetIPAddr = &net.IPAddr{ + IP: net.ParseIP("127.0.0.1"), + } +) + +type SrcFaker struct { + srcAddr *net.UDPAddr + + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer +} + +func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { + rawSocket, err := rawsocket.PrepareSenderRawSocket() + if err != nil { + return nil, err + } + + ipH, udpH, err := prepareHeaders(dstPort, srcAddr) + if err != nil { + return nil, err + } + + f := &SrcFaker{ + srcAddr: srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + } + + return f, nil +} + +func (f *SrcFaker) Close() error { + return f.rawSocket.Close() +} + +func (f *SrcFaker) SendPkg(data []byte) (int, error) { + defer func() { + if err := f.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload) + if err != nil { + return 0, fmt.Errorf("serialize layers: %w", err) + } + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + if err != nil { + return 0, fmt.Errorf("write to raw conn: %w", err) + } + return n, nil +} + +func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { + ipH := &layers.IPv4{ + DstIP: net.ParseIP("127.0.0.1"), + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + 
udpH := &layers.UDP{ + SrcPort: layers.UDPPort(srcAddr.Port), + DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port + } + + err := udpH.SetNetworkLayerForChecksum(ipH) + if err != nil { + return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return ipH, udpH, nil +} diff --git a/client/internal/auth/device_flow_test.go b/client/internal/auth/device_flow_test.go index dc950ac63..466645ee9 100644 --- a/client/internal/auth/device_flow_test.go +++ b/client/internal/auth/device_flow_test.go @@ -3,15 +3,17 @@ package auth import ( "context" "fmt" - "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/client/internal" - "github.com/stretchr/testify/require" "io" "net/http" "net/url" "strings" "testing" "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal" ) type mockHTTPClient struct { diff --git a/client/internal/connect.go b/client/internal/connect.go index b62a2d951..295d35a43 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" @@ -33,7 +34,7 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) @@ -246,7 +247,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.statusRecorder.MarkSignalConnected() relayURLs, token := parseRelayInfo(loginResp) - relayManager := relayClient.NewManager(engineCtx, relayURLs, 
myPrivateKey.PublicKey().String()) + peerConfig := loginResp.GetPeerConfig() + + engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) + if err != nil { + log.Error(err) + return wrapErr(err) + } + + relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) c.statusRecorder.SetRelayMgr(relayManager) if len(relayURLs) > 0 { if token != nil { @@ -262,7 +271,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } peerConfig := loginResp.GetPeerConfig() - engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, c.LogFile) if err != nil { log.Error(err) @@ -276,7 +284,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engineMutex.Unlock() - if err := c.engine.Start(); err != nil { + if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } @@ -285,10 +293,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan state.Set(StatusConnected) if runningChan != nil { - select { - case runningChan <- struct{}{}: - default: - } + close(runningChan) + runningChan = nil } <-engineCtx.Done() @@ -447,8 +453,8 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf LazyConnectionEnabled: config.LazyConnectionEnabled, LogFile: logFile, - ProfileConfig: config, + MTU: selectMTU(config.MTU, peerConfig.Mtu), } if config.PreSharedKey != "" { @@ -471,6 +477,20 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf return engineConf, nil } +func selectMTU(localMTU uint16, peerMTU int32) uint16 { + var finalMTU uint16 = iface.DefaultMTU + if localMTU > 0 { + finalMTU = localMTU + } else if peerMTU > 0 { + finalMTU = uint16(peerMTU) + } + + // Set global DNS MTU + 
dns.SetCurrentMTU(finalMTU) + + return finalMTU +} + // connectToSignal creates Signal Service client and established a connection func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) { var sigTLSEnabled bool diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go new file mode 100644 index 000000000..cb651f1e5 --- /dev/null +++ b/client/internal/dns/config/domains.go @@ -0,0 +1,201 @@ +package config + +import ( + "errors" + "fmt" + "net" + "net/netip" + "net/url" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/shared/management/domain" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +var ( + ErrEmptyURL = errors.New("empty URL") + ErrEmptyHost = errors.New("empty host") + ErrIPNotAllowed = errors.New("IP address not allowed") +) + +// ServerDomains represents the management server domains extracted from NetBird configuration +type ServerDomains struct { + Signal domain.Domain + Relay []domain.Domain + Flow domain.Domain + Stuns []domain.Domain + Turns []domain.Domain +} + +// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration +func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { + if config == nil { + return ServerDomains{} + } + + domains := ServerDomains{} + + domains.Signal = extractSignalDomain(config) + domains.Relay = extractRelayDomains(config) + domains.Flow = extractFlowDomain(config) + domains.Stuns = extractStunDomains(config) + domains.Turns = extractTurnDomains(config) + + return domains +} + +// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses +func ExtractValidDomain(rawURL string) (domain.Domain, error) { + if rawURL == "" { + return "", ErrEmptyURL + } + + parsedURL, err := url.Parse(rawURL) + if err == nil { + if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != 
"" { + return domain, err + } + } + + return extractFromRawString(rawURL) +} + +// extractFromParsedURL handles domain extraction from successfully parsed URLs +func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) { + if parsedURL.Hostname() != "" { + return extractDomainFromHost(parsedURL.Hostname()) + } + + if parsedURL.Opaque == "" || parsedURL.Scheme == "" { + return "", nil + } + + // Handle URLs with opaque content (e.g., stun:host:port) + if strings.Contains(parsedURL.Scheme, ".") { + // This is likely "domain.com:port" being parsed as scheme:opaque + reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque + if host, _, err := net.SplitHostPort(reconstructed); err == nil { + return extractDomainFromHost(host) + } + return extractDomainFromHost(parsedURL.Scheme) + } + + // Valid scheme with opaque content (e.g., stun:host:port) + host := parsedURL.Opaque + if queryIndex := strings.Index(host, "?"); queryIndex > 0 { + host = host[:queryIndex] + } + + if hostOnly, _, err := net.SplitHostPort(host); err == nil { + return extractDomainFromHost(hostOnly) + } + + return extractDomainFromHost(host) +} + +// extractFromRawString handles domain extraction when URL parsing fails or returns no results +func extractFromRawString(rawURL string) (domain.Domain, error) { + if host, _, err := net.SplitHostPort(rawURL); err == nil { + return extractDomainFromHost(host) + } + + return extractDomainFromHost(rawURL) +} + +// extractDomainFromHost extracts domain from a host string, filtering out IP addresses +func extractDomainFromHost(host string) (domain.Domain, error) { + if host == "" { + return "", ErrEmptyHost + } + + if _, err := netip.ParseAddr(host); err == nil { + return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host) + } + + d, err := domain.FromString(host) + if err != nil { + return "", fmt.Errorf("invalid domain: %v", err) + } + + return d, nil +} + +// extractSingleDomain extracts a single domain from a URL with error logging +func 
extractSingleDomain(url, serviceType string) domain.Domain { + if url == "" { + return "" + } + + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + return "" + } + + return d +} + +// extractMultipleDomains extracts multiple domains from URLs with error logging +func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { + var domains []domain.Domain + for _, url := range urls { + if url == "" { + continue + } + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + continue + } + domains = append(domains, d) + } + return domains +} + +// extractSignalDomain extracts the signal domain from NetBird configuration. +func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Signal != nil { + return extractSingleDomain(config.Signal.Uri, "signal") + } + return "" +} + +// extractRelayDomains extracts relay server domains from NetBird configuration. +func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Relay != nil { + return extractMultipleDomains(config.Relay.Urls, "relay") + } + return nil +} + +// extractFlowDomain extracts the traffic flow domain from NetBird configuration. +func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Flow != nil { + return extractSingleDomain(config.Flow.Url, "flow") + } + return "" +} + +// extractStunDomains extracts STUN server domains from NetBird configuration. +func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, stun := range config.Stuns { + if stun != nil && stun.Uri != "" { + urls = append(urls, stun.Uri) + } + } + return extractMultipleDomains(urls, "STUN") +} + +// extractTurnDomains extracts TURN server domains from NetBird configuration. 
+func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, turn := range config.Turns { + if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { + urls = append(urls, turn.HostConfig.Uri) + } + } + return extractMultipleDomains(urls, "TURN") +} diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go new file mode 100644 index 000000000..5eae3a541 --- /dev/null +++ b/client/internal/dns/config/domains_test.go @@ -0,0 +1,213 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractValidDomain(t *testing.T) { + tests := []struct { + name string + url string + expected string + expectError bool + }{ + { + name: "HTTPS URL with port", + url: "https://api.netbird.io:443", + expected: "api.netbird.io", + }, + { + name: "HTTP URL without port", + url: "http://signal.example.com", + expected: "signal.example.com", + }, + { + name: "Host with port (no scheme)", + url: "signal.netbird.io:443", + expected: "signal.netbird.io", + }, + { + name: "STUN URL", + url: "stun:stun.netbird.io:443", + expected: "stun.netbird.io", + }, + { + name: "STUN URL with different port", + url: "stun:stun.netbird.io:5555", + expected: "stun.netbird.io", + }, + { + name: "TURNS URL with query params", + url: "turns:turn.netbird.io:443?transport=tcp", + expected: "turn.netbird.io", + }, + { + name: "TURN URL", + url: "turn:turn.example.com:3478", + expected: "turn.example.com", + }, + { + name: "REL URL", + url: "rel://relay.example.com:443", + expected: "relay.example.com", + }, + { + name: "RELS URL", + url: "rels://relay.netbird.io:443", + expected: "relay.netbird.io", + }, + { + name: "Raw hostname", + url: "example.org", + expected: "example.org", + }, + { + name: "IP address should be rejected", + url: "192.168.1.1", + expectError: true, + }, + { + name: "IP address with port should be rejected", + url: "192.168.1.1:443", + expectError: 
true, + }, + { + name: "IPv6 address should be rejected", + url: "2001:db8::1", + expectError: true, + }, + { + name: "HTTP URL with IPv4 should be rejected", + url: "http://192.168.1.1:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv4 should be rejected", + url: "https://10.0.0.1:443", + expectError: true, + }, + { + name: "STUN URL with IPv4 should be rejected", + url: "stun:192.168.1.1:3478", + expectError: true, + }, + { + name: "TURN URL with IPv4 should be rejected", + url: "turn:10.0.0.1:3478", + expectError: true, + }, + { + name: "TURNS URL with IPv4 should be rejected", + url: "turns:172.16.0.1:5349", + expectError: true, + }, + { + name: "HTTP URL with IPv6 should be rejected", + url: "http://[2001:db8::1]:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv6 should be rejected", + url: "https://[::1]:443", + expectError: true, + }, + { + name: "STUN URL with IPv6 should be rejected", + url: "stun:[2001:db8::1]:3478", + expectError: true, + }, + { + name: "IPv6 with port should be rejected", + url: "[2001:db8::1]:443", + expectError: true, + }, + { + name: "Localhost IPv4 should be rejected", + url: "127.0.0.1:8080", + expectError: true, + }, + { + name: "Localhost IPv6 should be rejected", + url: "[::1]:443", + expectError: true, + }, + { + name: "REL URL with IPv4 should be rejected", + url: "rel://192.168.1.1:443", + expectError: true, + }, + { + name: "RELS URL with IPv4 should be rejected", + url: "rels://10.0.0.1:443", + expectError: true, + }, + { + name: "Empty URL", + url: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ExtractValidDomain(tt.url) + + if tt.expectError { + assert.Error(t, err, "Expected error for URL: %s", tt.url) + } else { + assert.NoError(t, err, "Unexpected error for URL: %s", tt.url) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url) + } + }) + } +} + +func TestExtractDomainFromHost(t 
*testing.T) { + tests := []struct { + name string + host string + expected string + expectError bool + }{ + { + name: "Valid domain", + host: "example.com", + expected: "example.com", + }, + { + name: "Subdomain", + host: "api.example.com", + expected: "api.example.com", + }, + { + name: "IPv4 address", + host: "192.168.1.1", + expectError: true, + }, + { + name: "IPv6 address", + host: "2001:db8::1", + expectError: true, + }, + { + name: "Empty host", + host: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := extractDomainFromHost(tt.host) + + if tt.expectError { + assert.Error(t, err, "Expected error for host: %s", tt.host) + } else { + assert.NoError(t, err, "Unexpected error for host: %s", tt.host) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host) + } + }) + } +} diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 439bcbb3c..2e54bffd9 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,11 +11,12 @@ import ( ) const ( - PriorityLocal = 100 - PriorityDNSRoute = 75 - PriorityUpstream = 50 - PriorityDefault = 1 - PriorityFallback = -100 + PriorityMgmtCache = 150 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 + PriorityFallback = -100 ) type SubdomainMatcher interface { @@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants to continue, try next handler if chainWriter.shouldContinue { - log.Tracef("handler requested continue to next handler for domain=%s", qname) + // Only log continue for non-management cache handlers to reduce noise + if entry.Priority != PriorityMgmtCache { + log.Tracef("handler requested continue to next handler for domain=%s", qname) + } continue } return diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 
852dfef48..b06ba73ab 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -166,9 +166,10 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { - err := s.recordSystemDNSSettings(true) - log.Errorf("Unable to get system DNS configuration") - return err + if err := s.recordSystemDNSSettings(true); err != nil { + log.Errorf("Unable to get system DNS configuration") + return fmt.Errorf("recordSystemDNSSettings(): %w", err) + } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index fdc2c3063..0d3f033fb 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { - policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) - if r.gpo { - policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) - } singleDomain := []string{domain} - if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) + if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) + } + + if r.gpo { + if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, 
fmt.Errorf("configure gpo DNS policy: %w", err) + } } log.Debugf("added NRPT entry for domain: %s", domain) @@ -401,6 +403,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) } @@ -412,6 +415,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) } diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index b776fbbe3..bac7875ec 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool { // String returns a string representation of the local resolver func (d *Resolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.records)) + return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) } func (d *Resolver) Stop() {} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go new file mode 100644 index 000000000..290395473 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt.go @@ -0,0 +1,360 @@ +package mgmt + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "sync" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/shared/management/domain" +) + +const dnsTimeout = 5 * time.Second + +// Resolver 
caches critical NetBird infrastructure domains +type Resolver struct { + records map[dns.Question][]dns.RR + mgmtDomain *domain.Domain + serverDomains *dnsconfig.ServerDomains + mutex sync.RWMutex +} + +// NewResolver creates a new management domains cache resolver. +func NewResolver() *Resolver { + return &Resolver{ + records: make(map[dns.Question][]dns.RR), + } +} + +// String returns a string representation of the resolver. +func (m *Resolver) String() string { + return "MgmtCacheResolver" +} + +// ServeDNS implements dns.Handler interface. +func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + m.continueToNext(w, r) + return + } + + question := r.Question[0] + question.Name = strings.ToLower(dns.Fqdn(question.Name)) + + if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { + m.continueToNext(w, r) + return + } + + m.mutex.RLock() + records, found := m.records[question] + m.mutex.RUnlock() + + if !found { + m.continueToNext(w, r) + return + } + + resp := &dns.Msg{} + resp.SetReply(r) + resp.Authoritative = false + resp.RecursionAvailable = true + + resp.Answer = append(resp.Answer, records...) + + log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write response: %v", err) + } +} + +// MatchSubdomains returns false since this resolver only handles exact domain matches +// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. +func (m *Resolver) MatchSubdomains() bool { + return false +} + +// continueToNext signals the handler chain to continue to the next handler. 
+func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write continue signal: %v", err) + } +} + +// AddDomain manually adds a domain to cache by resolving it. +func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + ctx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + if err != nil { + return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) + } + + var aRecords, aaaaRecords []dns.RR + for _, ip := range ips { + if ip.Is4() { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: ip.AsSlice(), + } + aRecords = append(aRecords, rr) + } else if ip.Is6() { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: ip.AsSlice(), + } + aaaaRecords = append(aaaaRecords, rr) + } + } + + m.mutex.Lock() + + if len(aRecords) > 0 { + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + m.records[aQuestion] = aRecords + } + + if len(aaaaRecords) > 0 { + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + m.records[aaaaQuestion] = aaaaRecords + } + + m.mutex.Unlock() + + log.Debugf("added domain=%s with %d A records and %d AAAA records", + d.SafeString(), len(aRecords), len(aaaaRecords)) + + return nil +} + +// PopulateFromConfig extracts and caches domains from the client configuration. 
+func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { + if mgmtURL == nil { + return nil + } + + d, err := dnsconfig.ExtractValidDomain(mgmtURL.String()) + if err != nil { + return fmt.Errorf("extract domain from URL: %w", err) + } + + m.mutex.Lock() + m.mgmtDomain = &d + m.mutex.Unlock() + + if err := m.AddDomain(ctx, d); err != nil { + return fmt.Errorf("add domain: %w", err) + } + + return nil +} + +// RemoveDomain removes a domain from the cache. +func (m *Resolver) RemoveDomain(d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + m.mutex.Lock() + defer m.mutex.Unlock() + + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + delete(m.records, aQuestion) + + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + delete(m.records, aaaaQuestion) + + log.Debugf("removed domain=%s from cache", d.SafeString()) + return nil +} + +// GetCachedDomains returns a list of all cached domains. +func (m *Resolver) GetCachedDomains() domain.List { + m.mutex.RLock() + defer m.mutex.RUnlock() + + domainSet := make(map[domain.Domain]struct{}) + for question := range m.records { + domainName := strings.TrimSuffix(question.Name, ".") + domainSet[domain.Domain(domainName)] = struct{}{} + } + + domains := make(domain.List, 0, len(domainSet)) + for d := range domainSet { + domains = append(domains, d) + } + + return domains +} + +// UpdateFromServerDomains updates the cache with server domains from network configuration. +// It merges new domains with existing ones, replacing entire domain types when updated. +// Empty updates are ignored to prevent clearing infrastructure domains during partial updates. 
+func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) { + newDomains := m.extractDomainsFromServerDomains(serverDomains) + var removedDomains domain.List + + if len(newDomains) > 0 { + m.mutex.Lock() + if m.serverDomains == nil { + m.serverDomains = &dnsconfig.ServerDomains{} + } + updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains) + m.serverDomains = &updatedServerDomains + m.mutex.Unlock() + + allDomains := m.extractDomainsFromServerDomains(updatedServerDomains) + currentDomains := m.GetCachedDomains() + removedDomains = m.removeStaleDomains(currentDomains, allDomains) + } + + m.addNewDomains(ctx, newDomains) + + return removedDomains, nil +} + +// removeStaleDomains removes cached domains not present in the target domain list. +// Management domains are preserved and never removed during server domain updates. +func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List { + var removedDomains domain.List + + for _, currentDomain := range currentDomains { + if m.isDomainInList(currentDomain, newDomains) { + continue + } + + if m.isManagementDomain(currentDomain) { + continue + } + + removedDomains = append(removedDomains, currentDomain) + if err := m.RemoveDomain(currentDomain); err != nil { + log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) + } + } + + return removedDomains +} + +// mergeServerDomains merges new server domains with existing ones. +// When a domain type is provided in the new domains, it completely replaces that type. 
+func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains { + merged := existing + + if incoming.Signal != "" { + merged.Signal = incoming.Signal + } + if len(incoming.Relay) > 0 { + merged.Relay = incoming.Relay + } + if incoming.Flow != "" { + merged.Flow = incoming.Flow + } + if len(incoming.Stuns) > 0 { + merged.Stuns = incoming.Stuns + } + if len(incoming.Turns) > 0 { + merged.Turns = incoming.Turns + } + + return merged +} + +// isDomainInList checks if domain exists in the list +func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool { + for _, d := range list { + if domain.SafeString() == d.SafeString() { + return true + } + } + return false +} + +// isManagementDomain checks if domain is the protected management domain +func (m *Resolver) isManagementDomain(domain domain.Domain) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return m.mgmtDomain != nil && domain == *m.mgmtDomain +} + +// addNewDomains resolves and caches all domains from the update +func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) { + for _, newDomain := range newDomains { + if err := m.AddDomain(ctx, newDomain); err != nil { + log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) + } else { + log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) + } + } +} + +func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List { + var domains domain.List + + if serverDomains.Signal != "" { + domains = append(domains, serverDomains.Signal) + } + + for _, relay := range serverDomains.Relay { + if relay != "" { + domains = append(domains, relay) + } + } + + if serverDomains.Flow != "" { + domains = append(domains, serverDomains.Flow) + } + + for _, stun := range serverDomains.Stuns { + if stun != "" { + domains = append(domains, stun) + } + } + + for _, turn := range serverDomains.Turns { + if turn != "" { 
+ domains = append(domains, turn) + } + } + + return domains +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go new file mode 100644 index 000000000..99d289871 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -0,0 +1,416 @@ +package mgmt + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +func TestResolver_NewResolver(t *testing.T) { + resolver := NewResolver() + + assert.NotNil(t, resolver) + assert.NotNil(t, resolver.records) + assert.False(t, resolver.MatchSubdomains()) +} + +func TestResolver_ExtractDomainFromURL(t *testing.T) { + tests := []struct { + name string + urlStr string + expectedDom string + expectError bool + }{ + { + name: "HTTPS URL with port", + urlStr: "https://api.netbird.io:443", + expectedDom: "api.netbird.io", + expectError: false, + }, + { + name: "HTTP URL without port", + urlStr: "http://signal.example.com", + expectedDom: "signal.example.com", + expectError: false, + }, + { + name: "URL with path", + urlStr: "https://relay.netbird.io/status", + expectedDom: "relay.netbird.io", + expectError: false, + }, + { + name: "Invalid URL", + urlStr: "not-a-valid-url", + expectedDom: "not-a-valid-url", + expectError: false, + }, + { + name: "Empty URL", + urlStr: "", + expectedDom: "", + expectError: true, + }, + { + name: "STUN URL", + urlStr: "stun:stun.example.com:3478", + expectedDom: "stun.example.com", + expectError: false, + }, + { + name: "TURN URL", + urlStr: "turn:turn.example.com:3478", + expectedDom: "turn.example.com", + expectError: false, + }, + { + name: "REL URL", + urlStr: "rel://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + { + name: "RELS URL", + urlStr: 
"rels://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var parsedURL *url.URL + var err error + + if tt.urlStr != "" { + parsedURL, err = url.Parse(tt.urlStr) + if err != nil && !tt.expectError { + t.Fatalf("Failed to parse URL: %v", err) + } + } + + domain, err := extractDomainFromURL(parsedURL) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedDom, domain.SafeString()) + } + }) + } +} + +func TestResolver_PopulateFromConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := NewResolver() + + // Test with IP address - should return error since IP addresses are rejected + mgmtURL, _ := url.Parse("https://127.0.0.1") + + err := resolver.PopulateFromConfig(ctx, mgmtURL) + assert.Error(t, err) + assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed) + + // No domains should be cached when using IP addresses + domains := resolver.GetCachedDomains() + assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") +} + +func TestResolver_ServeDNS(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Add a test domain to the cache - use example.org which is reserved for testing + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Test A record query for cached domain + t.Run("Cached domain A record", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + 
assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer") + }) + + // Test uncached domain signals to continue to next handler + t.Run("Uncached domain signals continue to next handler", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("unknown.example.com.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + // Zero flag set to true signals the handler chain to continue to next handler + assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain") + }) + + // Test that subdomains of cached domains are NOT resolved + t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query for a subdomain of our cached domain + req := new(dns.Msg) + req.SetQuestion("sub.example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains") + }) + + // Test case-insensitive matching + t.Run("Case-insensitive domain matching", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query with different casing + req := new(dns.Msg) + 
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case") + }) +} + +func TestResolver_GetCachedDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + cachedDomains := resolver.GetCachedDomains() + + assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain") + assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original") + assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot") +} + +func TestResolver_ManagementDomainProtection(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + mgmtURL, _ := url.Parse("https://example.org") + err := resolver.PopulateFromConfig(ctx, mgmtURL) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + initialDomains := resolver.GetCachedDomains() + if len(initialDomains) == 0 { + t.Skip("Management domain failed to resolve, skipping test") + } + assert.Equal(t, 1, len(initialDomains), "Should have management domain cached") + assert.Equal(t, "example.org", initialDomains[0].SafeString()) + + serverDomains := dnsconfig.ServerDomains{ + Signal: "google.com", + Relay: []domain.Domain{"cloudflare.com"}, + } + + _, err = resolver.UpdateFromServerDomains(ctx, serverDomains) + if err != nil { + t.Logf("Server domains update failed: %v", err) + } + + finalDomains := resolver.GetCachedDomains() + + managementStillCached := false + for _, d := range finalDomains { + 
if d.SafeString() == "example.org" { + managementStillCached = true + break + } + } + assert.True(t, managementStillCached, "Management domain should never be removed") +} + +// extractDomainFromURL extracts a domain from a URL - test helper function +func extractDomainFromURL(u *url.URL) (domain.Domain, error) { + if u == nil { + return "", fmt.Errorf("URL is nil") + } + return dnsconfig.ExtractValidDomain(u.String()) +} + +func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Verify domains were added + cachedDomains := resolver.GetCachedDomains() + assert.Len(t, cachedDomains, 3) + + // Update with empty ServerDomains (simulating partial network map update) + emptyDomains := dnsconfig.ServerDomains{} + removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains) + assert.NoError(t, err) + + // Verify no domains were removed + assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty") + + // Verify all original domains are still cached + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "All original domains should still be cached") +} + +func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := 
resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn) + partialDomains := dnsconfig.ServerDomains{ + Signal: "github.com", + } + removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Should remove only the old signal domain + assert.Len(t, removedDomains, 1, "Should remove only the old signal domain") + assert.Equal(t, "example.org", removedDomains[0].SafeString()) + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "github.com") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.NotContains(t, domainStrings, "example.org") +} + +func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only flow domain - new type, should preserve all existing) + partialDomains := dnsconfig.ServerDomains{ + Flow: "github.com", + } + removedDomains, err 
:= resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "example.org") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.Contains(t, domainStrings, "github.com") +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index d160fa99a..0f89b9016 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,20 +3,23 @@ package dns import ( "fmt" "net/netip" + "net/url" "github.com/miekg/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/shared/management/domain" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func(domain.List, dns.Handler, int) - DeregisterHandlerFunc func(domain.List, int) + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func(domain.List, dns.Handler, int) + DeregisterHandlerFunc func(domain.List, int) + UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error } func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string { // ProbeAvailability mocks implementation of ProbeAvailability from the 
Server interface func (m *MockServer) ProbeAvailability() { } + +func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + if m.UpdateServerConfigFunc != nil { + return m.UpdateServerConfigFunc(domains) + } + return nil +} + +func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { + return nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index cbcf6a256..8cb886203 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/netip" + "net/url" "runtime" "strings" "sync" @@ -15,7 +16,9 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -45,6 +48,8 @@ type Server interface { OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string ProbeAvailability() + UpdateServerConfig(domains dnsconfig.ServerDomains) error + PopulateManagementDomain(mgmtURL *url.URL) error } type nsGroupsByDomain struct { @@ -77,6 +82,8 @@ type DefaultServer struct { handlerChain *HandlerChain extraDomains map[domain.Domain]int + mgmtCacheResolver *mgmt.Resolver + // permanent related properties permanent bool hostsDNSHolder *hostsDNSHolder @@ -104,18 +111,20 @@ type handlerWrapper struct { type registeredHandlerMap map[types.HandlerID]handlerWrapper +// DefaultServerConfig holds configuration parameters for NewDefaultServer +type DefaultServerConfig struct { + WgInterface WGIface + CustomAddress string + StatusRecorder *peer.Status + StateManager *statemanager.Manager + DisableSys bool +} + // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx 
context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, - stateManager *statemanager.Manager, - disableSys bool, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) { var addrPort *netip.AddrPort - if customAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(customAddress) + if config.CustomAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) if err != nil { return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) } @@ -123,13 +132,14 @@ func NewDefaultServer( } var dnsService service - if wgInterface.IsUserspaceBind() { - dnsService = NewServiceViaMemory(wgInterface) + if config.WgInterface.IsUserspaceBind() { + dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(wgInterface, addrPort) + dnsService = newServiceViaListener(config.WgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil + server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) + return server, nil } // NewDefaultServerPermanentUpstream returns a new dns server. 
It optimized for mobile systems @@ -178,20 +188,24 @@ func newDefaultServer( ) *DefaultServer { handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) + + mgmtCacheResolver := mgmt.NewResolver() + defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: local.NewResolver(), - wgInterface: wgInterface, - statusRecorder: statusRecorder, - stateManager: stateManager, - hostsDNSHolder: newHostsDNSHolder(), - hostManager: &noopHostConfigurator{}, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), + wgInterface: wgInterface, + statusRecorder: statusRecorder, + stateManager: stateManager, + hostsDNSHolder: newHostsDNSHolder(), + hostManager: &noopHostConfigurator{}, + mgmtCacheResolver: mgmtCacheResolver, } // register with root zone, handler chain takes care of the routing @@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { - log.Debugf("registering handler %s with priority %d", handler, priority) + log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains) for _, domain := range domains { if domain == "" { @@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { } func (s *DefaultServer) deregisterHandler(domains []string, priority int) { - log.Debugf("deregistering handler %v with priority %d", domains, priority) + log.Debugf("deregistering handler with priority %d for %v", priority, domains) for _, domain := range domains { if domain == "" { @@ -432,6 +446,29 @@ func (s 
*DefaultServer) ProbeAvailability() { wg.Wait() } +func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + s.mux.Lock() + defer s.mux.Unlock() + + if s.mgmtCacheResolver != nil { + removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) + if err != nil { + return fmt.Errorf("update management cache resolver: %w", err) + } + + if len(removedDomains) > 0 { + s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) + } + + newDomains := s.mgmtCacheResolver.GetCachedDomains() + if len(newDomains) > 0 { + s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) + } + } + + return nil +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver if update.ServiceEnable { @@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain { ), ) } + +// PopulateManagementDomain populates the DNS cache with management domain +func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { + if s.mgmtCacheResolver != nil { + return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) + } + return nil +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 068f001d8..11575d500 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatal(err) } @@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, 
"", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: &mocWGIface{}, + CustomAddress: testCase.addrPort, + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 89d637686..6ef0ab526 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type ServiceViaMemory struct { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index f5d0e775f..c19e0acb5 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -26,10 +26,18 @@ import ( "github.com/netbirdio/netbird/client/proto" ) -const ( - UpstreamTimeout = 15 * time.Second +var currentMTU uint16 = iface.DefaultMTU + +func SetCurrentMTU(mtu uint16) { + currentMTU = mtu +} + +const ( + UpstreamTimeout = 4 * time.Second + // ClientTimeout is the timeout for the dns.Client. 
+ // Set longer than UpstreamTimeout to ensure context timeout takes precedence + ClientTimeout = 5 * time.Second - failsTillDeact = int32(5) reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second ) @@ -52,9 +60,7 @@ type upstreamResolverBase struct { upstreamServers []netip.AddrPort domain string disabled bool - failsCount atomic.Int32 successCount atomic.Int32 - failsTillDeact int32 mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration @@ -73,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain: domain, upstreamTimeout: UpstreamTimeout, reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.upstreamServers) } // ID returns the unique handler ID @@ -110,58 +115,102 @@ func (u *upstreamResolverBase) Stop() { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { requestID := GenerateRequestID() logger := log.WithField("request_id", requestID) - var err error - defer func() { - u.checkUpstreamFails(err) - }() logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + u.prepareRequest(r) + + if u.ctx.Err() != nil { + logger.Tracef("%s has been stopped", u) + return + } + + if u.tryUpstreamServers(w, r, logger) { + return + } + + u.writeErrorResponse(w, r, logger) +} + +func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } +} - select { - case <-u.ctx.Done(): - logger.Tracef("%s has been stopped", u) - return - default: +func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { + timeout := u.upstreamTimeout + if 
len(u.upstreamServers) > 1 { + maxTotal := 5 * time.Second + minPerUpstream := 2 * time.Second + scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) + if scaledTimeout > minPerUpstream { + timeout = scaledTimeout + } else { + timeout = minPerUpstream + } } for _, upstream := range u.upstreamServers { - var rm *dns.Msg - var t time.Duration - - func() { - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - defer cancel() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) - }() - - if err != nil { - if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) - continue - } - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) - continue + if u.queryUpstream(w, r, upstream, timeout, logger) { + return true } + } + return false +} - if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - continue - } +func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool { + var rm *dns.Msg + var t time.Duration + var err error - u.successCount.Add(1) - logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + var startTime time.Time + func() { + ctx, cancel := context.WithTimeout(u.ctx, timeout) + defer cancel() + startTime = time.Now() + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) + }() - if err = w.WriteMsg(rm); err != nil { - logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) - } - // count the fails only if they happen sequentially - u.failsCount.Store(0) + if err != nil { + u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) + return false + } + + if rm == nil || 
!rm.Response { + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + return false + } + + return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) +} + +func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { + if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) return } - u.failsCount.Add(1) + + elapsed := time.Since(startTime) + timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) + if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { + timeoutMsg += " " + peerInfo + } + timeoutMsg += fmt.Sprintf(" - error: %v", err) + logger.Warnf(timeoutMsg) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { + u.successCount.Add(1) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain) + + if err := w.WriteMsg(rm); err != nil { + logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) + } + return true +} + +func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) @@ -171,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -// checkUpstreamFails counts fails and disables or enables upstream resolving -// -// If fails count is greater that failsTillDeact, upstream resolving -// will be disabled for reactivatePeriod, after that time period fails counter -// will be reset and upstream will be 
reactivated. -func (u *upstreamResolverBase) checkUpstreamFails(err error) { - u.mutex.Lock() - defer u.mutex.Unlock() - - if u.failsCount.Load() < u.failsTillDeact || u.disabled { - return - } - - select { - case <-u.ctx.Done(): - return - default: - } - - u.disable(err) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (fail count exceeded)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - // TODO add domain meta - ) -} - // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability() { @@ -218,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() { default: } - // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact - if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { + // avoid probe if upstreams could resolve at least one query + if u.successCount.Load() > 0 { return } @@ -306,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() { } log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) - u.failsCount.Store(0) u.successCount.Add(1) u.reactivate() u.disabled = false @@ -358,8 +371,8 @@ func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout tim // If the passed context is nil, this will use Exchange instead of ExchangeContext. func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { // MTU - ip + udp headers - // Note: this could be sent out on an interface that is not ours, but our MTU should always be lower. 
- client.UDPSize = iface.DefaultMTU - (60 + 8) + // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling. + client.UDPSize = uint16(currentMTU - (60 + 8)) var ( rm *dns.Msg @@ -410,3 +423,80 @@ func GenerateRequestID() string { } return hex.EncodeToString(bytes) } + +// FormatPeerStatus formats peer connection status information for debugging DNS timeouts +func FormatPeerStatus(peerState *peer.State) string { + isConnected := peerState.ConnStatus == peer.StatusConnected + hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() && + time.Since(peerState.LastWireguardHandshake) < 3*time.Minute + + statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP) + + switch { + case !isConnected: + statusInfo += " DISCONNECTED" + case !hasRecentHandshake: + statusInfo += " NO_RECENT_HANDSHAKE" + default: + statusInfo += " connected" + } + + if !peerState.LastWireguardHandshake.IsZero() { + timeSinceHandshake := time.Since(peerState.LastWireguardHandshake) + statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second)) + } else { + statusInfo += " no_handshake" + } + + if peerState.Relayed { + statusInfo += " via_relay" + } + + if peerState.Latency > 0 { + statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency) + } + + return statusInfo +} + +// findPeerForIP finds which peer handles the given IP address +func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { + if statusRecorder == nil { + return nil + } + + fullStatus := statusRecorder.GetFullStatus() + var bestMatch *peer.State + var bestPrefixLen int + + for _, peerState := range fullStatus.Peers { + routes := peerState.GetRoutes() + for route := range routes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + continue + } + + if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen { + peerStateCopy := peerState + bestMatch = &peerStateCopy + bestPrefixLen = prefix.Bits() + 
} + } + } + + return bestMatch +} + +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ddbf84ae4..def281f28 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type upstreamResolver struct { @@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns } func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - upstreamExchangeClient := &dns.Client{} + upstreamExchangeClient := &dns.Client{ + Timeout: ClientTimeout, + } return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } @@ -72,10 +74,11 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri } upstreamExchangeClient := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: timeout, } - return upstreamExchangeClient.Exchange(r, upstream) + return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 317588a27..434e5880b 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -34,7 +34,10 @@ func newUpstreamResolver( } func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t 
time.Duration, err error) { - return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) + client := &dns.Client{ + Timeout: ClientTimeout, + } + return ExchangeWithFallback(ctx, client, r, upstream) } func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 96b8bbb0f..eadcdd117 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -47,7 +47,9 @@ func newUpstreamResolver( } func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - client := &dns.Client{} + client := &dns.Client{ + Timeout: ClientTimeout, + } upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) @@ -110,7 +112,8 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura }, } client := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: dialTimeout, } return client, nil } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 51d870e2a..e1573e75e 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) } func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { + mockClient := &mockUpstreamResolver{ + err: dns.ErrTime, + r: new(dns.Msg), + rtt: time.Millisecond, + } + resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: &mockUpstreamResolver{ - err: nil, - r: new(dns.Msg), - rtt: time.Millisecond, - }, + ctx: context.TODO(), + upstreamClient: mockClient, upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, + reactivatePeriod: time.Microsecond * 100, } addrPort, _ := 
netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - resolver.failsTillDeact = 0 - resolver.reactivatePeriod = time.Microsecond * 100 - - responseWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { return nil }, - } failed := false resolver.deactivate = func(error) { failed = true + // After deactivation, make the mock client work again + mockClient.err = nil } reactivated := false @@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) + resolver.ProbeAvailability() if !failed { t.Errorf("expected that resolving was deactivated") @@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { return } - if resolver.failsCount.Load() != 0 { - t.Errorf("fails count after reactivation should be 0") - return - } - if resolver.disabled { t.Errorf("should be enabled") } diff --git a/client/internal/engine.go b/client/internal/engine.go index 4e847758d..fefe2e96c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "net/netip" + "net/url" "os" "reflect" "runtime" @@ -17,8 +18,8 @@ import ( "time" "github.com/hashicorp/go-multierror" - "github.com/pion/ice/v3" - "github.com/pion/stun/v2" + "github.com/pion/ice/v4" + "github.com/pion/stun/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -28,12 +29,13 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" nbnetstack 
"github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" @@ -134,6 +136,7 @@ type EngineConfig struct { ProfileConfig *profilemanager.Config LogFile string + MTU uint16 } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -171,7 +174,7 @@ type Engine struct { wgInterface WGIface - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -207,6 +210,10 @@ type Engine struct { jobExecutor *jobexec.Executor jobExecutorWG sync.WaitGroup + + // WireGuard interface monitor + wgIfaceMonitor *WGIfaceMonitor + wgIfaceMonitorWg sync.WaitGroup } // Peer is an instance of the Connection Peer @@ -343,16 +350,23 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } + // Stop WireGuard interface monitor and wait for it to exit + e.wgIfaceMonitorWg.Wait() + return nil } // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. 
// However, they will be established once an event with a list of peers to connect to will be received from Management Service -func (e *Engine) Start() error { +func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() + if err := iface.ValidateMTU(e.config.MTU); err != nil { + return fmt.Errorf("invalid MTU configuration: %w", err) + } + if e.cancel != nil { e.cancel() } @@ -401,6 +415,11 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer + // Populate DNS cache with NetbirdConfig and management URL for early resolution + if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache: %v", err) + } + e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ Context: e.ctx, PublicKey: e.config.WgPrivateKey.PublicKey().String(), @@ -439,6 +458,8 @@ func (e *Engine) Start() error { return fmt.Errorf("up wg interface: %w", err) } + + // if inbound conns are blocked there is no need to create the ACL manager if e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) @@ -454,7 +475,7 @@ func (e *Engine) Start() error { StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), } @@ -471,6 +492,22 @@ func (e *Engine) Start() error { // starting network monitor at the very last to avoid disruptions e.startNetworkMonitor() + + // monitor WireGuard interface lifecycle and restart engine on changes + e.wgIfaceMonitor = NewWGIfaceMonitor() + e.wgIfaceMonitorWg.Add(1) + + go func() { + defer e.wgIfaceMonitorWg.Done() + + if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { + log.Infof("WireGuard interface monitor: %s, restarting engine", err) + 
e.restartEngine() + } else if err != nil { + log.Warnf("WireGuard interface monitor: %s", err) + } + }() + return nil } @@ -662,6 +699,30 @@ func (e *Engine) removePeer(peerKey string) error { return nil } +// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response +func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { + if e.dnsServer == nil { + return nil + } + + // Populate management URL if provided + if mgmtURL != nil { + if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache with management URL: %v", err) + } + } + + // Populate NetbirdConfig domains if provided + if netbirdConfig != nil { + serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig) + if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { + return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err) + } + } + + return nil +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -693,6 +754,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("handle the flow configuration: %w", err) } + if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil { + log.Warnf("Failed to update DNS server config: %v", err) + } + // todo update signal } @@ -1001,7 +1066,6 @@ func (e *Engine) receiveManagementEvents() { e.config.LazyConnectionEnabled, ) - // err = e.mgmClient.Sync(info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) if err != nil { // happens if management is unavailable for a long time. 
@@ -1012,7 +1076,7 @@ func (e *Engine) receiveManagementEvents() { } log.Debugf("stopped receiving updates from Management Service") }() - log.Debugf("connecting to Management Service updates stream") + log.Infof("connecting to Management Service updates stream") } func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { @@ -1204,15 +1268,16 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { } convertedRoute := &route.Route{ - ID: route.ID(protoRoute.ID), - Network: prefix.Masked(), - Domains: domain.FromPunycodeList(protoRoute.Domains), - NetID: route.NetID(protoRoute.NetID), - NetworkType: route.NetworkType(protoRoute.NetworkType), - Peer: protoRoute.Peer, - Metric: int(protoRoute.Metric), - Masquerade: protoRoute.Masquerade, - KeepRoute: protoRoute.KeepRoute, + ID: route.ID(protoRoute.ID), + Network: prefix.Masked(), + Domains: domain.FromPunycodeList(protoRoute.Domains), + NetID: route.NetID(protoRoute.NetID), + NetworkType: route.NetworkType(protoRoute.NetworkType), + Peer: protoRoute.Peer, + Metric: int(protoRoute.Metric), + Masquerade: protoRoute.Masquerade, + KeepRoute: protoRoute.KeepRoute, + SkipAutoApply: protoRoute.SkipAutoApply, } routes = append(routes, convertedRoute) } @@ -1378,7 +1443,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), }, @@ -1584,7 +1649,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { Address: e.config.WgAddr, WGPort: e.config.WgPort, WGPrivKey: e.config.WgPrivateKey.String(), - MTU: iface.DefaultMTU, + MTU: e.config.MTU, TransportNet: transportNet, FilterFn: e.addrViaRoutes, DisableDNS: e.config.DisableDNS, @@ -1643,7 +1708,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, 
error) { return dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) + + dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ + WgInterface: e.wgInterface, + CustomAddress: e.config.CustomDNSAddress, + StatusRecorder: e.statusRecorder, + StateManager: e.stateManager, + DisableSys: e.config.DisableDNS, + }) if err != nil { return nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 1a179c6ce..aeeb68e79 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -19,22 +19,22 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" - wgdevice "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" @@ -46,9 +46,12 @@ import ( "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" 
"github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -86,7 +89,7 @@ type MockWGIface struct { NameFunc func() string AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface - UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error @@ -136,7 +139,7 @@ func (m *MockWGIface) ToInterface() *net.Interface { return m.ToInterfaceFunc() } -func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { return m.UpFunc() } @@ -219,14 +222,25 @@ func TestEngine_SSH(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ - WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - ServerSSHAllowed: true, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + engine := NewEngine( + ctx, cancel, + &signal.MockClient{}, + &mgmt.MockClient{}, + relayMgr, + &EngineConfig{ + WgIfaceName: "utun101", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + ServerSSHAllowed: true, + MTU: iface.DefaultMTU, 
+ }, + MobileDependency{}, + peer.NewRecorder("https://mgm"), + nil, + ) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -257,7 +271,7 @@ func TestEngine_SSH(t *testing.T) { }, }, nil } - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) } @@ -355,13 +369,23 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + engine := NewEngine( + ctx, cancel, + &signal.MockClient{}, + &mgmt.MockClient{}, + relayMgr, + &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + MTU: iface.DefaultMTU, + }, + MobileDependency{}, + peer.NewRecorder("https://mgm"), + nil) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -396,7 +420,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) + engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) engine.ctx = ctx engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) @@ -573,13 +597,14 @@ func TestEngine_Sync(t *testing.T) { } return nil } - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), 
iface.DefaultMTU) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ WgIfaceName: "utun103", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -593,7 +618,7 @@ func TestEngine_Sync(t *testing.T) { } }() - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) return @@ -737,13 +762,14 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx newNet, err := stdnet.NewNet() if err != nil { @@ -938,13 +964,14 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) + MTU: iface.DefaultMTU, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx 
newNet, err := stdnet.NewNet() @@ -1048,7 +1075,7 @@ func TestEngine_MultiplePeers(t *testing.T) { defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) wg.Done() @@ -1165,6 +1192,7 @@ func Test_ParseNATExternalIPMappings(t *testing.T) { config: &EngineConfig{ IFaceBlackList: testCase.inputBlacklistInterface, NATExternalIPs: testCase.inputMapList, + MTU: iface.DefaultMTU, }, } parsedList := engine.parseNATExternalIPMappings() @@ -1465,10 +1493,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin WgAddr: resp.PeerConfig.Address, WgPrivateKey: key, WgPort: wgPort, + MTU: iface.DefaultMTU, } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil + e.ctx = ctx return e, err } @@ -1533,7 +1563,11 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + permissionsManager := permissions.NewManager(store) + peersManager := peers.NewManager(store, permissionsManager) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1550,7 +1584,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri Return(&types.ExtraSettings{}, nil). 
AnyTimes() - permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index bf96153ea..690fdb7cc 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -9,9 +9,9 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -24,7 +24,7 @@ type wgIfaceBase interface { Name() string Address() wgaddr.Address ToInterface() *net.Interface - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/login.go b/client/internal/login.go index d5412a110..257e3c3ac 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, return false, err } - _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) + _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) if isLoginNeeded(err) { return true, nil } @@ -69,14 +69,18 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string, return err } - serverKey, err := 
doMgmLogin(ctx, mgmClient, pubSSHKey, config) + serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) if serverKey != nil && isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) + if err != nil { + return err + } + } else if err != nil { return err } - return err + return nil } func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { @@ -101,11 +105,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm return mgmClient, err } -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { serverKey, err := mgmClient.GetServerPublicKey() if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err + return nil, nil, err } sysInfo := system.GetInfo(ctx) @@ -121,8 +125,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockInbound, config.LazyConnectionEnabled, ) - _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, err + loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) + return serverKey, loginResp, err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. 
diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index dbb4747a5..a4ffa3a25 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -14,7 +14,7 @@ import ( "github.com/ti-mo/netfilter" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const defaultChannelSize = 100 diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index a6cf3cd25..8db9e58f4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -6,12 +6,11 @@ import ( "math/rand" "net" "net/netip" - "os" "runtime" "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -29,10 +28,6 @@ import ( semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) -const ( - defaultWgKeepAlive = 25 * time.Second -) - type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler @@ -118,6 +113,8 @@ type Conn struct { // debug purpose dumpState *stateDump + + endpointUpdater *EndpointUpdater } // NewConn creates a new not opened Conn to the remote peer. 
@@ -130,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), } return conn, nil @@ -174,7 +172,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) - if os.Getenv("NB_FORCE_RELAY") != "true" { + if !isForceRelayed() { conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } @@ -250,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgProxyICE = nil } - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.RemoveWgPeer(); err != nil { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } @@ -376,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn wgProxy.Work() } - if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { + conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) 
+ presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) + if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() + if conn.wgProxyRelay != nil { + conn.Log.Debugf("redirect packets from relayed conn to WireGuard") + conn.wgProxyRelay.RedirectAs(ep) + } + conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -410,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { + presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil { conn.Log.Errorf("failed to switch to relay conn: %v", err) } @@ -419,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() { defer conn.wgWatcherWg.Done() conn.workerRelay.EnableWgWatcher(conn.ctx) }() + conn.wgProxyRelay.Work() conn.currentConnPriority = conntype.Relay } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. 
Reset priority to: %s", conntype.None.String()) @@ -478,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { + presharedKey := conn.presharedKey(rci.rosenpassPubKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) } @@ -546,17 +554,6 @@ func (conn *Conn) onGuardEvent() { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { - presharedKey := conn.presharedKey(remoteRPKey) - return conn.config.WgConfig.WgInterface.UpdatePeer( - conn.config.WgConfig.RemoteKey, - conn.config.WgConfig.AllowedIps, - defaultWgKeepAlive, - addr, - presharedKey, - ) -} - func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -699,10 +696,6 @@ func (conn *Conn) isICEActive() bool { return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected } -func (conn *Conn) removeWgPeer() error { - return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.Log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go new file mode 100644 index 000000000..39cb95591 --- /dev/null +++ b/client/internal/peer/endpoint.go @@ -0,0 +1,105 @@ +package peer + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const ( + defaultWgKeepAlive = 25 * time.Second + fallbackDelay = 5 * time.Second +) + +type EndpointUpdater struct { + 
log *logrus.Entry + wgConfig WgConfig + initiator bool + + // mu protects updateWireGuardPeer and cancelFunc + mu sync.Mutex + cancelFunc func() + updateWg sync.WaitGroup +} + +func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater { + return &EndpointUpdater{ + log: log, + wgConfig: wgConfig, + initiator: initiator, + } +} + +// ConfigureWGEndpoint sets up the WireGuard endpoint configuration. +// The initiator immediately configures the endpoint, while the non-initiator +// waits for a fallback period before configuring to avoid handshake congestion. +func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.initiator { + e.log.Debugf("configure up WireGuard as initiatr") + return e.updateWireGuardPeer(addr, presharedKey) + } + + // prevent to run new update while cancel the previous update + e.waitForCloseTheDelayedUpdate() + + var ctx context.Context + ctx, e.cancelFunc = context.WithCancel(context.Background()) + e.updateWg.Add(1) + go e.scheduleDelayedUpdate(ctx, addr, presharedKey) + + e.log.Debugf("configure up WireGuard and wait for handshake") + return e.updateWireGuardPeer(nil, presharedKey) +} + +func (e *EndpointUpdater) RemoveWgPeer() error { + e.mu.Lock() + defer e.mu.Unlock() + + e.waitForCloseTheDelayedUpdate() + return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) +} + +func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { + if e.cancelFunc == nil { + return + } + + e.cancelFunc() + e.cancelFunc = nil + e.updateWg.Wait() +} + +// scheduleDelayedUpdate waits for the fallback period before updating the endpoint +func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) { + defer e.updateWg.Done() + t := time.NewTimer(fallbackDelay) + defer t.Stop() + + select { + case <-ctx.Done(): + return + case <-t.C: + e.mu.Lock() + if err := 
e.updateWireGuardPeer(addr, presharedKey); err != nil { + e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) + } + e.mu.Unlock() + } +} + +func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error { + return e.wgConfig.WgInterface.UpdatePeer( + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + endpoint, + presharedKey, + ) +} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go new file mode 100644 index 000000000..32a458d00 --- /dev/null +++ b/client/internal/peer/env.go @@ -0,0 +1,14 @@ +package peer + +import ( + "os" + "strings" +) + +const ( + EnvKeyNBForceRelay = "NB_FORCE_RELAY" +) + +func isForceRelayed() bool { + return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index b9c9aa134..70850e6eb 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 3cbf74cfd..42eaea683 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -43,13 +43,6 @@ type OfferAnswer struct { SessionID *ICESessionID } -func (oa *OfferAnswer) SessionIDString() string { - if oa.SessionID == nil { - return "unknown" - } - return oa.SessionID.String() -} - type Handshaker struct { mu sync.Mutex log *log.Entry @@ -57,7 +50,7 @@ type Handshaker struct { signaler *Signaler ice *WorkerICE relay *WorkerRelay - onNewOfferListeners []func(*OfferAnswer) + onNewOfferListeners []*OfferListener // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan 
OfferAnswer @@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - h.onNewOfferListeners = append(h.onNewOfferListeners, offer) + l := NewOfferListener(offer) + h.onNewOfferListeners = append(h.onNewOfferListeners, l) } func (h *Handshaker) Listen(ctx context.Context) { @@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } for _, listener := range h.onNewOfferListeners { - listener(&remoteOfferAnswer) + listener.Notify(&remoteOfferAnswer) } h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) for _, listener := range h.onNewOfferListeners { - listener(&remoteOfferAnswer) + listener.Notify(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go new file mode 100644 index 000000000..e2d3f3f38 --- /dev/null +++ b/client/internal/peer/handshaker_listener.go @@ -0,0 +1,62 @@ +package peer + +import ( + "sync" +) + +type callbackFunc func(remoteOfferAnswer *OfferAnswer) + +func (oa *OfferAnswer) SessionIDString() string { + if oa.SessionID == nil { + return "unknown" + } + return oa.SessionID.String() +} + +type OfferListener struct { + fn callbackFunc + running bool + latest *OfferAnswer + mu sync.Mutex +} + +func NewOfferListener(fn callbackFunc) *OfferListener { + return &OfferListener{ + fn: fn, + } +} + +func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { + o.mu.Lock() 
+ defer o.mu.Unlock() + + // Store the latest offer + o.latest = remoteOfferAnswer + + // If already running, the running goroutine will pick up this latest value + if o.running { + return + } + + // Start processing + o.running = true + + // Process in a goroutine to avoid blocking the caller + go func(remoteOfferAnswer *OfferAnswer) { + for { + o.fn(remoteOfferAnswer) + + o.mu.Lock() + if o.latest == nil { + // No more work to do + o.running = false + o.mu.Unlock() + return + } + remoteOfferAnswer = o.latest + // Clear the latest to mark it as being processed + o.latest = nil + o.mu.Unlock() + } + }(remoteOfferAnswer) +} diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go new file mode 100644 index 000000000..8363741a5 --- /dev/null +++ b/client/internal/peer/handshaker_listener_test.go @@ -0,0 +1,39 @@ +package peer + +import ( + "testing" + "time" +) + +func Test_newOfferListener(t *testing.T) { + dummyOfferAnswer := &OfferAnswer{} + runChan := make(chan struct{}, 10) + + longRunningFn := func(remoteOfferAnswer *OfferAnswer) { + time.Sleep(1 * time.Second) + runChan <- struct{}{} + } + + hl := NewOfferListener(longRunningFn) + + hl.Notify(dummyOfferAnswer) + hl.Notify(dummyOfferAnswer) + hl.Notify(dummyOfferAnswer) + + // Wait for exactly 2 callbacks + for i := 0; i < 2; i++ { + select { + case <-runChan: + case <-time.After(3 * time.Second): + t.Fatal("Timeout waiting for callback") + } + } + + // Verify no additional callbacks happen + select { + case <-runChan: + t.Fatal("Unexpected additional callback") + case <-time.After(100 * time.Millisecond): + t.Log("Correctly received exactly 2 callbacks") + } +} diff --git a/client/internal/peer/ice/StunTurn.go b/client/internal/peer/ice/StunTurn.go index 63ee8c713..a389f5444 100644 --- a/client/internal/peer/ice/StunTurn.go +++ b/client/internal/peer/ice/StunTurn.go @@ -3,7 +3,7 @@ package ice import ( "sync/atomic" - "github.com/pion/stun/v2" + 
"github.com/pion/stun/v3" ) type StunTurn atomic.Value diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 4a0228405..e80c98884 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,9 +1,10 @@ package ice import ( + "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" "github.com/pion/logging" "github.com/pion/randutil" log "github.com/sirupsen/logrus" @@ -23,7 +24,20 @@ const ( iceRelayAcceptanceMinWaitDefault = 2 * time.Second ) -func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { +type ThreadSafeAgent struct { + *ice.Agent + once sync.Once +} + +func (a *ThreadSafeAgent) Close() error { + var err error + a.once.Do(func() { + err = a.Agent.Close() + }) + return err +} + +func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() iceFailedTimeout := iceFailedTimeout() @@ -61,7 +75,12 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} } - return ice.NewAgent(agentConfig) + agent, err := ice.NewAgent(agentConfig) + if err != nil { + return nil, err + } + + return &ThreadSafeAgent{Agent: agent}, nil } func GenerateICECredentials() (string, string, error) { diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go index dd854a605..dd5d67403 100644 --- a/client/internal/peer/ice/config.go +++ b/client/internal/peer/ice/config.go @@ -1,7 +1,7 @@ package ice import ( - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" ) type Config struct { diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index ca1d421a5..b28906625 100644 --- 
a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -1,7 +1,7 @@ package peer import ( - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 218872c15..0ed200fda 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -30,9 +30,10 @@ type WGWatcher struct { peerKey string stateDump *stateDump - ctx context.Context - ctxCancel context.CancelFunc - ctxLock sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex + enabledTime time.Time } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() + w.enabledTime = time.Now() if w.ctx != nil && w.ctx.Err() == nil { w.log.Errorf("WireGuard watcher already enabled") @@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex onDisconnectedFn() return } + if lastHandshake.IsZero() { + elapsed := handshake.Sub(w.enabledTime).Seconds() + w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) + } + lastHandshake = *handshake resetTime := time.Until(handshake.Add(checkPeriod)) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index ee85254fb..eb886a4d3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -8,12 +8,11 @@ import ( "sync" "time" - "github.com/pion/ice/v3" - "github.com/pion/stun/v2" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" - 
"github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -42,7 +41,7 @@ type WorkerICE struct { statusRecorder *Status hasRelayOnLocally bool - agent *ice.Agent + agent *icemaker.ThreadSafeAgent agentDialerCancel context.CancelFunc agentConnecting bool // while it is true, drop all incoming offers lastSuccess time.Time // with this avoid the too frequent ICE agent recreation @@ -55,10 +54,6 @@ type WorkerICE struct { sessionID ICESessionID muxAgent sync.Mutex - StunTurn []*stun.URI - - sentExtraSrflx bool - localUfrag string localPwd string @@ -121,7 +116,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } - // todo consider to switch to Relay connection while establishing a new ICE connection + w.agent = nil } var preferredCandidateTypes []ice.CandidateType @@ -139,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.muxAgent.Unlock() return } - w.sentExtraSrflx = false w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true @@ -166,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA w.log.Errorf("error while handling remote candidate") return } + + if shouldAddExtraCandidate(candidate) { + // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) + // this is useful when network has an existing port forwarding rule for the wireguard port and this peer + extraSrflx, err := extraSrflxCandidate(candidate) + if err != nil { + w.log.Errorf("failed creating extra server reflexive candidate %s", err) + return + } + + if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil { + 
w.log.Errorf("error while handling remote candidate") + return + } + } } func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { @@ -195,7 +204,7 @@ func (w *WorkerICE) Close() { w.agent = nil } -func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) { +func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) @@ -209,14 +218,12 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates [] return nil, err } - if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil { + if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) { + w.onICESelectedCandidatePair(agent, c1, c2) + }); err != nil { return nil, err } - if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil { - return nil, fmt.Errorf("failed setting binding response callback: %w", err) - } - return agent, nil } @@ -230,7 +237,7 @@ func (w *WorkerICE) SessionID() ICESessionID { // will block until connection succeeded // but it won't release if ICE Agent went into Disconnected or Failed state, // so we have to cancel it with the provided context once agent detected a broken connection -func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAnswer *OfferAnswer) { +func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) { w.log.Debugf("gather candidates") if err := agent.GatherCandidates(); err != nil { w.log.Warnf("failed to gather candidates: %s", err) @@ -239,7 +246,7 @@ func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAn } w.log.Debugf("turn agent 
dial") - remoteConn, err := w.turnAgentDial(ctx, remoteOfferAnswer) + remoteConn, err := w.turnAgentDial(ctx, agent, remoteOfferAnswer) if err != nil { w.log.Debugf("failed to dial the remote peer: %s", err) w.closeAgent(agent, w.agentDialerCancel) @@ -252,6 +259,11 @@ func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAn w.closeAgent(agent, w.agentDialerCancel) return } + if pair == nil { + w.log.Warnf("selected candidate pair is nil, cannot proceed") + w.closeAgent(agent, w.agentDialerCancel) + return + } if !isRelayCandidate(pair.Local) { // dynamically set remote WireGuard port if other side specified a different one from the default one @@ -290,13 +302,14 @@ func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAn w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) } -func (w *WorkerICE) closeAgent(agent *ice.Agent, cancel context.CancelFunc) { +func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) { cancel() if err := agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } w.muxAgent.Lock() + // todo review does it make sense to generate new session ID all the time when w.agent==agent sessionID, err := NewICESessionID() if err != nil { w.log.Errorf("failed to create new session ID: %s", err) @@ -325,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) return } - mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) + mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault) if !ok { w.log.Warn("invalid udp mux conversion") return @@ -352,41 +365,36 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() - - if !w.shouldSendExtraSrflxCandidate(candidate) { - return - } - - // sends an extra server reflexive candidate to the remote peer with our related port (usually the 
wireguard port) - // this is useful when network has an existing port forwarding rule for the wireguard port and this peer - extraSrflx, err := extraSrflxCandidate(candidate) - if err != nil { - w.log.Errorf("failed creating extra server reflexive candidate %s", err) - return - } - w.sentExtraSrflx = true - - go func() { - err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key) - if err != nil { - w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err) - } - }() } -func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { +func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.config.Key) + + pairStat, ok := agent.GetSelectedCandidatePairStats() + if !ok { + w.log.Warnf("failed to get selected candidate pair stats") + return + } + + duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second)) + if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { + w.log.Debugf("failed to update latency for peer: %s", err) + return + } } -func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { +func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { return func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) switch state { case ice.ConnectionStateConnected: w.lastKnownState = ice.ConnectionStateConnected return - case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: + case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed: + // ice.ConnectionStateClosed happens when we recreate the agent. 
For the P2P to TURN switch important to + // notify the conn.onICEStateDisconnected changes to update the current used priority + if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() @@ -398,32 +406,34 @@ func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel conte } } -func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) { - if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil { - w.log.Debugf("failed to update latency for peer: %s", err) - return - } -} - -func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { - if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { - return true - } - return false -} - -func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { - isControlling := w.config.LocalKey > w.config.Key - if isControlling { +func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { + if isController(w.config) { return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { - return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } } +func shouldAddExtraCandidate(candidate ice.Candidate) bool { + if candidate.Type() != ice.CandidateTypeServerReflexive { + return false + } + + if candidate.Port() == candidate.RelatedAddress().Port { + return false + } + + // in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates + // in newer version we generate locally the extra candidate + if _, ok := 
candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok { + return false + } + return true +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() - return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ Network: candidate.NetworkType().String(), Address: candidate.Address(), Port: relatedAdd.Port, @@ -431,6 +441,21 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive RelAddr: relatedAdd.Address, RelPort: relatedAdd.Port, }) + if err != nil { + return nil, err + } + + for _, e := range candidate.Extensions() { + // overwrite the original candidate ID with the new one to avoid candidate duplication + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = candidate.ID() + } + if err := ec.AddExtension(e); err != nil { + return nil, err + } + } + + return ec, nil } func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 6bbdbd984..4e6b422f6 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -75,6 +75,8 @@ type ConfigInput struct { DNSLabels domain.List LazyConnectionEnabled *bool + + MTU *uint16 } // Config Configuration type @@ -141,6 +143,8 @@ type Config struct { ClientCertKeyPair *tls.Certificate `json:"-"` LazyConnectionEnabled bool + + MTU uint16 } var ConfigDirOverride string @@ -493,6 +497,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } + if input.MTU != nil && *input.MTU != config.MTU { + log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU) + config.MTU = *input.MTU + updated = true + } else if config.MTU == 0 { + config.MTU = iface.DefaultMTU + log.Infof("using default MTU %d", config.MTU) + updated = 
true + } + return updated, nil } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 6e1f83a9a..fa208716f 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -7,12 +7,12 @@ import ( "sync" "time" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/turn/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ProbeResult holds the info about the result of a relay probe request diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index ba27df654..9069cdcc5 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -2,11 +2,13 @@ package dnsinterceptor import ( "context" + "errors" "fmt" "net/netip" "runtime" "strings" "sync" + "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -26,6 +28,8 @@ import ( "github.com/netbirdio/netbird/route" ) +const dnsTimeout = 8 * time.Second + type domainMap map[domain.Domain][]netip.Prefix type internalDNATer interface { @@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) + client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) if err != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return @@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + 
+ startTime := time.Now() + reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) if err != nil { - logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + if errors.Is(err, context.DeadlineExceeded) { + elapsed := time.Since(startTime) + peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) + logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", + elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err) + } else { + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + } if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { logger.Errorf("failed writing DNS response: %v", err) } @@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR } return } + +func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { + if d.statusRecorder == nil { + return "" + } + + peerState, err := d.statusRecorder.GetPeer(peerKey) + if err != nil { + return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err) + } + + return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState)) +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index da5534902..04513bbe4 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,9 +36,9 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" - relayClient "github.com/netbirdio/netbird/shared/relay/client" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/route" - nbnet 
"github.com/netbirdio/netbird/util/net" + relayClient "github.com/netbirdio/netbird/shared/relay/client" "github.com/netbirdio/netbird/version" ) @@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager { notifier := notifier.NewNotifier() sysOps := systemops.NewSysOps(config.WGInterface, notifier) + if runtime.GOOS == "windows" && config.WGInterface != nil { + nbnet.SetVPNInterfaceName(config.WGInterface.Name()) + } + dm := &DefaultManager{ ctx: mCTX, stop: cancel, @@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error { return nil } - if err := m.sysOps.CleanupRouting(nil); err != nil { + if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error { ips := resolveURLsToIPs(initialAddresses) - if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { + if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil { return fmt.Errorf("setup routing: %w", err) } @@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { - if err := m.sysOps.CleanupRouting(stateManager); err != nil { + if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") } + + if runtime.GOOS == "windows" { + nbnet.SetVPNInterfaceName("") + } } m.mux.Lock() @@ -368,7 +376,11 @@ func (m *DefaultManager) UpdateRoutes( var merr *multierror.Error if !m.disableClientRoutes { - filteredClientRoutes := m.routeSelector.FilterSelected(clientRoutes) + + // Update route selector based on management server's isSelected status + m.updateRouteSelectorFromManagement(clientRoutes) + + filteredClientRoutes := m.routeSelector.FilterSelectedExitNodes(clientRoutes) if err := m.updateSystemRoutes(filteredClientRoutes); 
err != nil { merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err)) @@ -430,7 +442,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { m.mux.Lock() defer m.mux.Unlock() - networks = m.routeSelector.FilterSelected(networks) + networks = m.routeSelector.FilterSelectedExitNodes(networks) m.notifier.OnNewRoutes(networks) @@ -583,3 +595,106 @@ func resolveURLsToIPs(urls []string) []net.IP { } return ips } + +// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server +func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) { + exitNodeInfo := m.collectExitNodeInfo(clientRoutes) + if len(exitNodeInfo.allIDs) == 0 { + return + } + + m.updateExitNodeSelections(exitNodeInfo) + m.logExitNodeUpdate(exitNodeInfo) +} + +type exitNodeInfo struct { + allIDs []route.NetID + selectedByManagement []route.NetID + userSelected []route.NetID + userDeselected []route.NetID +} + +func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo { + var info exitNodeInfo + + for haID, routes := range clientRoutes { + if !m.isExitNodeRoute(routes) { + continue + } + + netID := haID.NetID() + info.allIDs = append(info.allIDs, netID) + + if m.routeSelector.HasUserSelectionForRoute(netID) { + m.categorizeUserSelection(netID, &info) + } else { + m.checkManagementSelection(routes, netID, &info) + } + } + + return info +} + +func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool { + return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR +} + +func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) { + if m.routeSelector.IsSelected(netID) { + info.userSelected = append(info.userSelected, netID) + } else { + info.userDeselected = append(info.userDeselected, netID) + } +} + +func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID route.NetID, info *exitNodeInfo) { + 
for _, route := range routes { + if !route.SkipAutoApply { + info.selectedByManagement = append(info.selectedByManagement, netID) + break + } + } +} + +func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) { + routesToDeselect := m.getRoutesToDeselect(info.allIDs) + m.deselectExitNodes(routesToDeselect) + m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs) +} + +func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID { + var routesToDeselect []route.NetID + for _, netID := range allIDs { + if !m.routeSelector.HasUserSelectionForRoute(netID) { + routesToDeselect = append(routesToDeselect, netID) + } + } + return routesToDeselect +} + +func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) { + if len(routesToDeselect) == 0 { + return + } + + err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect) + if err != nil { + log.Warnf("Failed to deselect exit nodes: %v", err) + } +} + +func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) { + if len(selectedByManagement) == 0 { + return + } + + err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs) + if err != nil { + log.Warnf("Failed to select exit nodes: %v", err) + } +} + +func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) { + log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected", + len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected)) +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2f13c2134..d2f02526c 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -190,14 +190,15 @@ func TestManagerUpdateRoutes(t *testing.T) { name: "No Small Client Route Should Be Added", inputRoutes: []*route.Route{ { - ID: "a", 
- NetID: "routeA", - Peer: remotePeerKey1, - Network: netip.MustParsePrefix("0.0.0.0/0"), - NetworkType: route.IPv4Network, - Metric: 9999, - Masquerade: false, - Enabled: true, + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("0.0.0.0/0"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + SkipAutoApply: false, }, }, inputSerial: 1, diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index a375ce832..7cb8dae93 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -12,11 +12,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 128afa2a5..26a548634 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -3,7 +3,6 @@ package systemops import ( - "context" "errors" "fmt" "net" @@ -22,7 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net/hooks" ) const localSubnetsCacheTTL = 15 * time.Minute @@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { return nil } - // TODO: Remove hooks 
selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() + hooks.RemoveWriteHooks() + hooks.RemoveCloseHooks() + hooks.RemoveAddressRemoveHooks() if err := r.refCounter.Flush(); err != nil { return fmt.Errorf("flush route manager: %w", err) @@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := util.GetPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - + beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } @@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } - afterHook := func(connID nbnet.ConnectionID) error { + afterHook := func(connID hooks.ConnectionID) error { if err := r.refCounter.DecrementWithID(string(connID)); err != nil { return fmt.Errorf("remove route reference: %w", err) } @@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M var merr *multierror.Error for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err)) + continue + } + if err := beforeHook("init", prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err)) } } - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } + 
hooks.AddWriteHook(beforeHook) + hooks.AddCloseHook(afterHook) - var merr *multierror.Error - for _, ip := range resolvedIPs { - merr = multierror.Append(merr, beforeHook(connID, ip.IP)) - } - return nberrors.FormatErrorOrNil(merr) - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { + hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.Decrement(prefix); err != nil { return fmt.Errorf("remove route reference: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index c1c1182bc..32ea38a7a 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" ) type dialer interface { @@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) intf, err := net.InterfaceByName(wgInterface.Name()) @@ -341,10 +343,11 @@ func 
TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 10356eae0..99a363371 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -12,14 +12,14 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index c0cef94ba..bd10f131f 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ 
b/client/internal/routemanager/systemops/systemops_linux.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // IPRule contains IP rule information for debugging @@ -94,15 +94,15 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { - if !nbnet.AdvancedRouting() { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) { + if !advancedRouting { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) } defer func() { if err != nil { - if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. 
-func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { - if !nbnet.AdvancedRouting() { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if !advancedRouting { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index f165f7779..d43c2d5bf 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -20,11 +20,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index ad37f611f..959c697e4 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type PacketExpectation struct { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 36e714ec4..7bce6af80 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ 
-8,6 +8,7 @@ import ( "net/netip" "os" "runtime/debug" + "sort" "strconv" "sync" "syscall" @@ -19,9 +20,16 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" ) -const InfiniteLifetime = 0xffffffff +func init() { + nbnet.GetBestInterfaceFunc = GetBestInterface +} + +const ( + InfiniteLifetime = 0xffffffff +) type RouteUpdateType int @@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct { Table [1]MIB_IPFORWARD_ROW2 // Flexible array member } +// candidateRoute represents a potential route for selection during route lookup +type candidateRoute struct { + interfaceIndex uint32 + prefixLength uint8 + routeMetric uint32 + interfaceMetric int +} + // IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix type IP_ADDRESS_PREFIX struct { Prefix SOCKADDR_INET @@ -177,11 +193,20 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + return r.cleanupRefCounter(stateManager) } @@ -336,7 +361,7 @@ func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error { if e1 != 0 { return fmt.Errorf("CreateIpForwardEntry2: %w", e1) } - return fmt.Errorf("CreateIpForwardEntry2: code %d", r1) + return fmt.Errorf("CreateIpForwardEntry2: code %d", windows.NTStatus(r1)) } return nil } @@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) { func 
freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) { if table != nil { - ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) - if ret != 0 { - log.Warnf("FreeMibTable failed with return code: %d", ret) - } + _, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) } } @@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute { entryPtr := basePtr + uintptr(i)*entrySize entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) - detailed := buildWindowsDetailedRoute(entry) - if detailed != nil { + if detailed := buildWindowsDetailedRoute(entry); detailed != nil { detailedRoutes = append(detailedRoutes, *detailed) } } @@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { return ip } +// parseCandidatesFromTable extracts all matching candidate routes from the routing table +func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute { + var candidates []candidateRoute + entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{}) + basePtr := uintptr(unsafe.Pointer(&table.Table[0])) + + for i := uint32(0); i < table.NumEntries; i++ { + entryPtr := basePtr + uintptr(i)*entrySize + entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) + + if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil { + candidates = append(candidates, *candidate) + } + } + + return candidates +} + +// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry +// Returns nil if the route doesn't match the destination or should be skipped +func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute { + if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex { + return nil + } + + destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex)) + if !destPrefix.IsValid() || !destPrefix.Contains(dest) { + return nil + 
} + + interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family) + + return &candidateRoute{ + interfaceIndex: entry.InterfaceIndex, + prefixLength: entry.DestinationPrefix.PrefixLength, + routeMetric: entry.Metric, + interfaceMetric: interfaceMetric, + } +} + // getInterfaceMetric retrieves the interface metric for a given interface and address family func getInterfaceMetric(interfaceIndex uint32, family int16) int { if interfaceIndex == 0 { @@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int { return int(ipInterfaceRow.Metric) } +// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric +func sortRouteCandidates(candidates []candidateRoute) { + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].prefixLength != candidates[j].prefixLength { + return candidates[i].prefixLength > candidates[j].prefixLength + } + if candidates[i].routeMetric != candidates[j].routeMetric { + return candidates[i].routeMetric < candidates[j].routeMetric + } + return candidates[i].interfaceMetric < candidates[j].interfaceMetric + }) +} + +// GetBestInterface finds the best interface for reaching a destination, +// excluding the VPN interface to avoid routing loops. +// +// Route selection priority: +// 1. Longest prefix match (most specific route) +// 2. Lowest route metric +// 3. 
Lowest interface metric +func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { + var skipInterfaceIndex int + if vpnIntf != "" { + if iface, err := net.InterfaceByName(vpnIntf); err == nil { + skipInterfaceIndex = iface.Index + } else { + // not critical, if we cannot get ahold of the interface then we won't need to skip it + log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err) + } + } + + table, err := getWindowsRoutingTable() + if err != nil { + return nil, fmt.Errorf("get routing table: %w", err) + } + defer freeWindowsRoutingTable(table) + + candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex) + + if len(candidates) == 0 { + return nil, fmt.Errorf("no route to %s", dest) + } + + // Sort routes: prefix length -> route metric -> interface metric + sortRouteCandidates(candidates) + + for _, candidate := range candidates { + iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex)) + if err != nil { + log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err) + continue + } + + if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() { + continue + } + + if iface.Flags&net.FlagUp == 0 { + log.Debugf("interface %s is down, trying next route", iface.Name) + continue + } + + log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d", + dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric) + return iface, nil + } + + return nil, fmt.Errorf("no usable interface found for %s", dest) +} + // formatRouteAge formats the route age in seconds to a human-readable string func formatRouteAge(ageSeconds uint32) string { if ageSeconds == 0 { diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 523bd0b0d..3561adec4 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ 
b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) var ( diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go index ac5a48e37..57ea32f69 100644 --- a/client/internal/routemanager/util/ip.go +++ b/client/internal/routemanager/util/ip.go @@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) { if !ok { return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip) } + addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) + prefix := netip.PrefixFrom(addr, addr.BitLen()) return prefix, nil } diff --git a/client/internal/routemanager/vars/vars.go b/client/internal/routemanager/vars/vars.go index 4aa986d2f..ac11dec8c 100644 --- a/client/internal/routemanager/vars/vars.go +++ b/client/internal/routemanager/vars/vars.go @@ -13,4 +13,6 @@ var ( Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) + + ExitNodeCIDR = "0.0.0.0/0" ) diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 8ebdc63e5..e4a78599e 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -9,19 +9,27 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" ) +const ( + exitNodeCIDR = "0.0.0.0/0" +) + type RouteSelector struct { mu sync.RWMutex deselectedRoutes map[route.NetID]struct{} + selectedRoutes 
map[route.NetID]struct{} deselectAll bool } func NewRouteSelector() *RouteSelector { return &RouteSelector{ deselectedRoutes: map[route.NetID]struct{}{}, + selectedRoutes: map[route.NetID]struct{}{}, deselectAll: false, } } @@ -32,7 +40,14 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al defer rs.mu.Unlock() if !appendRoute || rs.deselectAll { + if rs.deselectedRoutes == nil { + rs.deselectedRoutes = map[route.NetID]struct{}{} + } + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } maps.Clear(rs.deselectedRoutes) + maps.Clear(rs.selectedRoutes) for _, r := range allRoutes { rs.deselectedRoutes[r] = struct{}{} } @@ -45,6 +60,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al continue } delete(rs.deselectedRoutes, route) + rs.selectedRoutes[route] = struct{}{} } rs.deselectAll = false @@ -58,7 +74,14 @@ func (rs *RouteSelector) SelectAllRoutes() { defer rs.mu.Unlock() rs.deselectAll = false + if rs.deselectedRoutes == nil { + rs.deselectedRoutes = map[route.NetID]struct{}{} + } + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } maps.Clear(rs.deselectedRoutes) + maps.Clear(rs.selectedRoutes) } // DeselectRoutes removes specific routes from the selection. @@ -77,6 +100,7 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route. continue } rs.deselectedRoutes[route] = struct{}{} + delete(rs.selectedRoutes, route) } return errors.FormatErrorOrNil(err) @@ -88,7 +112,14 @@ func (rs *RouteSelector) DeselectAllRoutes() { defer rs.mu.Unlock() rs.deselectAll = true + if rs.deselectedRoutes == nil { + rs.deselectedRoutes = map[route.NetID]struct{}{} + } + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } maps.Clear(rs.deselectedRoutes) + maps.Clear(rs.selectedRoutes) } // IsSelected checks if a specific route is selected. 
@@ -97,11 +128,14 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { defer rs.mu.RUnlock() if rs.deselectAll { + log.Debugf("Route %s not selected (deselect all)", routeID) return false } _, deselected := rs.deselectedRoutes[routeID] - return !deselected + isSelected := !deselected + log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected) + return isSelected } // FilterSelected removes unselected routes from the provided map. @@ -124,15 +158,98 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { return filtered } +// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route +func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool { + rs.mu.RLock() + defer rs.mu.RUnlock() + + _, selected := rs.selectedRoutes[routeID] + _, deselected := rs.deselectedRoutes[routeID] + return selected || deselected +} + +func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap { + rs.mu.RLock() + defer rs.mu.RUnlock() + + if rs.deselectAll { + return route.HAMap{} + } + + filtered := make(route.HAMap, len(routes)) + for id, rt := range routes { + netID := id.NetID() + if rs.isDeselected(netID) { + continue + } + + if !isExitNode(rt) { + filtered[id] = rt + continue + } + + rs.applyExitNodeFilter(id, netID, rt, filtered) + } + + return filtered +} + +func (rs *RouteSelector) isDeselected(netID route.NetID) bool { + _, deselected := rs.deselectedRoutes[netID] + return deselected || rs.deselectAll +} + +func isExitNode(rt []*route.Route) bool { + return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR +} + +func (rs *RouteSelector) applyExitNodeFilter( + id route.HAUniqueID, + netID route.NetID, + rt []*route.Route, + out route.HAMap, +) { + + if rs.hasUserSelections() { + // user made explicit selects/deselects + if rs.IsSelected(netID) { + out[id] = rt + } + return + } + + // no explicit selections: only include 
routes marked !SkipAutoApply (=AutoApply) + sel := collectSelected(rt) + if len(sel) > 0 { + out[id] = sel + } +} + +func (rs *RouteSelector) hasUserSelections() bool { + return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0 +} + +func collectSelected(rt []*route.Route) []*route.Route { + var sel []*route.Route + for _, r := range rt { + if !r.SkipAutoApply { + sel = append(sel, r) + } + } + return sel +} + // MarshalJSON implements the json.Marshaler interface func (rs *RouteSelector) MarshalJSON() ([]byte, error) { rs.mu.RLock() defer rs.mu.RUnlock() return json.Marshal(struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` DeselectAll bool `json:"deselect_all"` }{ + SelectedRoutes: rs.selectedRoutes, DeselectedRoutes: rs.deselectedRoutes, DeselectAll: rs.deselectAll, }) @@ -147,11 +264,13 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { // Check for null or empty JSON if len(data) == 0 || string(data) == "null" { rs.deselectedRoutes = map[route.NetID]struct{}{} + rs.selectedRoutes = map[route.NetID]struct{}{} rs.deselectAll = false return nil } var temp struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` DeselectAll bool `json:"deselect_all"` } @@ -160,12 +279,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { return err } + rs.selectedRoutes = temp.SelectedRoutes rs.deselectedRoutes = temp.DeselectedRoutes rs.deselectAll = temp.DeselectAll if rs.deselectedRoutes == nil { rs.deselectedRoutes = map[route.NetID]struct{}{} } + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } return nil } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index cfa723246..5faea2456 100644 --- a/client/internal/routeselector/routeselector_test.go +++ 
b/client/internal/routeselector/routeselector_test.go @@ -1,6 +1,7 @@ package routeselector_test import ( + "net/netip" "slices" "testing" @@ -273,6 +274,62 @@ func TestRouteSelector_FilterSelected(t *testing.T) { }, filtered) } +func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) { + rs := routeselector.NewRouteSelector() + + // Create test routes + exitNode1 := &route.Route{ + ID: "route1", + NetID: "net1", + Network: netip.MustParsePrefix("0.0.0.0/0"), + Peer: "peer1", + SkipAutoApply: false, + } + exitNode2 := &route.Route{ + ID: "route2", + NetID: "net1", + Network: netip.MustParsePrefix("0.0.0.0/0"), + Peer: "peer2", + SkipAutoApply: true, + } + normalRoute := &route.Route{ + ID: "route3", + NetID: "net2", + Network: netip.MustParsePrefix("192.168.1.0/24"), + Peer: "peer3", + SkipAutoApply: false, + } + + routes := route.HAMap{ + "net1|0.0.0.0/0": {exitNode1, exitNode2}, + "net2|192.168.1.0/24": {normalRoute}, + } + + // Test filtering + filtered := rs.FilterSelectedExitNodes(routes) + + // Should only include selected exit nodes and all normal routes + assert.Len(t, filtered, 2) + assert.Len(t, filtered["net1|0.0.0.0/0"], 1) // Only the selected exit node + assert.Equal(t, exitNode1.ID, filtered["net1|0.0.0.0/0"][0].ID) + assert.Len(t, filtered["net2|192.168.1.0/24"], 1) // Normal route should be included + assert.Equal(t, normalRoute.ID, filtered["net2|192.168.1.0/24"][0].ID) + + // Test with deselected routes + err := rs.DeselectRoutes([]route.NetID{"net1"}, []route.NetID{"net1", "net2"}) + assert.NoError(t, err) + filtered = rs.FilterSelectedExitNodes(routes) + assert.Len(t, filtered, 1) // Only normal route should remain + assert.Len(t, filtered["net2|192.168.1.0/24"], 1) + assert.Equal(t, normalRoute.ID, filtered["net2|192.168.1.0/24"][0].ID) + + // Test with deselect all + rs = routeselector.NewRouteSelector() + rs.DeselectAllRoutes() + filtered = rs.FilterSelectedExitNodes(routes) + assert.Len(t, filtered, 0) // No routes should be selected 
+} + func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialRoutes := []route.NetID{"route1", "route2", "route3"} newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go index e80adb42b..8961eaa69 100644 --- a/client/internal/stdnet/dialer.go +++ b/client/internal/stdnet/dialer.go @@ -5,7 +5,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // Dial connects to the address on the named network. diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go index 9ce0a5556..d3be1896f 100644 --- a/client/internal/stdnet/listener.go +++ b/client/internal/stdnet/listener.go @@ -6,7 +6,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ListenPacket listens for incoming packets on the given network and address. diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index aa9fdd045..4b031c05c 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" ) @@ -32,9 +33,15 @@ type Net struct { // NewNetWithDiscover creates a new StdNet instance. 
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { n := &Net{ - iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover), interfaceFilter: InterfaceFilter(disallowList), } + // current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client + // so in android cli use pionDiscover + if netstack.IsEnabled() { + n.iFaceDiscover = pionDiscover{} + } else { + n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover) + } return n, n.UpdateInterfaces() } diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go new file mode 100644 index 000000000..78d70c15b --- /dev/null +++ b/client/internal/wg_iface_monitor.go @@ -0,0 +1,98 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net" + "runtime" + "time" + + log "github.com/sirupsen/logrus" +) + +// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine +// if the interface is deleted externally while the engine is running. +type WGIfaceMonitor struct { + done chan struct{} +} + +// NewWGIfaceMonitor creates a new WGIfaceMonitor instance. +func NewWGIfaceMonitor() *WGIfaceMonitor { + return &WGIfaceMonitor{ + done: make(chan struct{}), + } +} + +// Start begins monitoring the WireGuard interface. +// It relies on the provided context cancellation to stop. 
+func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { + defer close(m.done) + + // Skip on mobile platforms as they handle interface lifecycle differently + if runtime.GOOS == "android" || runtime.GOOS == "ios" { + log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS) + return false, errors.New("not supported on mobile platforms") + } + + if ifaceName == "" { + log.Debugf("Interface monitor: empty interface name, skipping monitor") + return false, errors.New("empty interface name") + } + + // Get initial interface index to track the specific interface instance + expectedIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName) + return false, fmt.Errorf("interface %s not found: %w", ifaceName, err) + } + + log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } + +} + +// getInterfaceIndex returns the index of a network interface by name. +// Returns an error if the interface is not found. 
+func getInterfaceIndex(name string) (int, error) { + if name == "" { + return 0, fmt.Errorf("empty interface name") + } + ifi, err := net.InterfaceByName(name) + if err != nil { + // Match *net.OpError by type: errors.Is against a fresh &net.OpError{} compares pointer identity and never matches + var opErr *net.OpError + if errors.As(err, &opErr) { + return 0, fmt.Errorf("interface not found: %w", err) + } + return 0, fmt.Errorf("failed to lookup interface: %w", err) + } + if ifi == nil { + return 0, fmt.Errorf("interface not found") + } + return ifi.Index, nil +} diff --git a/client/net/conn.go b/client/net/conn.go new file mode 100644 index 000000000..918e7f628 --- /dev/null +++ b/client/net/conn.go @@ -0,0 +1,49 @@ +//go:build !ios + +package net + +import ( + "io" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/net/hooks" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID hooks.ConnectionID +} + +// Close overrides the net.Conn Close method: it closes the underlying connection +// and then executes all registered close hooks. +func (c *Conn) Close() error { + return closeConn(c.ID, c.Conn) +} + +// TCPConn wraps net.TCPConn to override its Close method to include hook functionality. +type TCPConn struct { + *net.TCPConn + ID hooks.ConnectionID +} + +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. +func (c *TCPConn) Close() error { + return closeConn(c.ID, c.TCPConn) +} + +// closeConn is a helper function to close connections and execute close hooks. 
+func closeConn(id hooks.ConnectionID, conn io.Closer) error { + err := conn.Close() + + closeHooks := hooks.GetCloseHooks() + for _, hook := range closeHooks { + if err := hook(id); err != nil { + log.Errorf("Error executing close hook: %v", err) + } + } + + return err +} diff --git a/client/net/dial.go b/client/net/dial.go new file mode 100644 index 000000000..041a00e5d --- /dev/null +++ b/client/net/dial.go @@ -0,0 +1,82 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + udpConn, ok := c.Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.TCPConn: + 
// Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + tcpConn, ok := c.Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn) + } + return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/dial_ios.go b/client/net/dial_ios.go similarity index 100% rename from util/net/dial_ios.go rename to client/net/dial_ios.go diff --git a/util/net/dialer.go b/client/net/dialer.go similarity index 99% rename from util/net/dialer.go rename to client/net/dialer.go index 0786c667e..29bec05a7 100644 --- a/util/net/dialer.go +++ b/client/net/dialer.go @@ -16,6 +16,5 @@ func NewDialer() *Dialer { Dialer: &net.Dialer{}, } dialer.init() - return dialer } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go new file mode 100644 index 000000000..2e1eb53d8 --- /dev/null +++ b/client/net/dialer_dial.go @@ -0,0 +1,87 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + log.Debugf("Dialing %s %s", network, address) + + if CustomRoutingDisabled() || AdvancedRouting() { + return d.Dialer.DialContext(ctx, network, address) + } + + connID := hooks.GenerateConnID() + if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil 
{ + log.Errorf("Failed to call dialer hooks: %v", err) + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error { + if ctx.Err() != nil { + return ctx.Err() + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + + resolver := customResolver + if resolver == nil { + resolver = net.DefaultResolver + } + + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var merr *multierror.Error + for _, ip := range ips { + prefix, err := util.GetPrefixFromIP(ip.IP) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err)) + continue + } + for _, hook := range writeHooks { + if err := hook(connID, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err)) + } + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/dialer_init_android.go b/client/net/dialer_init_android.go similarity index 100% rename from util/net/dialer_init_android.go rename to client/net/dialer_init_android.go diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go new file mode 100644 index 000000000..18ebc6ad1 --- /dev/null +++ 
b/client/net/dialer_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (d *Dialer) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/dialer_init_linux.go b/client/net/dialer_init_linux.go similarity index 100% rename from util/net/dialer_init_linux.go rename to client/net/dialer_init_linux.go diff --git a/client/net/dialer_init_windows.go b/client/net/dialer_init_windows.go new file mode 100644 index 000000000..6eefe5b1e --- /dev/null +++ b/client/net/dialer_init_windows.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyUnicastIFToSocket +} diff --git a/util/net/env.go b/client/net/env.go similarity index 94% rename from util/net/env.go rename to client/net/env.go index 32425665d..8f326ca88 100644 --- a/util/net/env.go +++ b/client/net/env.go @@ -11,6 +11,7 @@ import ( const ( envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" ) // CustomRoutingDisabled returns true if custom routing is disabled. diff --git a/client/net/env_android.go b/client/net/env_android.go new file mode 100644 index 000000000..9d89951a1 --- /dev/null +++ b/client/net/env_android.go @@ -0,0 +1,24 @@ +//go:build android + +package net + +// Init initializes the network environment for Android +func Init() { + // No initialization needed on Android +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on Android since we cannot handle routes dynamically. 
+func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on Android +func SetVPNInterfaceName(name string) { + // No-op on Android - not needed for Android VPN service +} + +// GetVPNInterfaceName returns empty string on Android +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_generic.go b/client/net/env_generic.go new file mode 100644 index 000000000..f467930c3 --- /dev/null +++ b/client/net/env_generic.go @@ -0,0 +1,23 @@ +//go:build !linux && !windows && !android + +package net + +// Init initializes the network environment (no-op on non-Linux/Windows platforms) +func Init() { + // No-op on non-Linux/Windows platforms +} + +// AdvancedRouting returns false on non-Linux/Windows platforms +func AdvancedRouting() bool { + return false +} + +// SetVPNInterfaceName is a no-op on non-Windows platforms +func SetVPNInterfaceName(name string) { + // No-op on non-Windows platforms +} + +// GetVPNInterfaceName returns empty string on non-Windows platforms +func GetVPNInterfaceName() string { + return "" +} diff --git a/util/net/env_linux.go b/client/net/env_linux.go similarity index 86% rename from util/net/env_linux.go rename to client/net/env_linux.go index 3159f6462..82d9a74a8 100644 --- a/util/net/env_linux.go +++ b/client/net/env_linux.go @@ -17,8 +17,7 @@ import ( const ( // these have the same effect, skip socket env supported for backward compatibility - envSkipSocketMark = "NB_SKIP_SOCKET_MARK" - envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" ) var advancedRoutingSupported bool @@ -27,6 +26,7 @@ func Init() { advancedRoutingSupported = checkAdvancedRoutingSupport() } +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes func AdvancedRouting() bool { return advancedRoutingSupported } @@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool { } func CheckFwmarkSupport() bool { - // temporarily enable advanced routing to 
check fwmarks are supported + // temporarily enable advanced routing to check if fwmarks are supported old := advancedRoutingSupported advancedRoutingSupported = true defer func() { @@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool { } return true } + +// SetVPNInterfaceName is a no-op on Linux +func SetVPNInterfaceName(name string) { + // No-op on Linux - not needed for fwmark-based routing +} + +// GetVPNInterfaceName returns empty string on Linux +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_windows.go b/client/net/env_windows.go new file mode 100644 index 000000000..7e8868ba5 --- /dev/null +++ b/client/net/env_windows.go @@ -0,0 +1,67 @@ +//go:build windows + +package net + +import ( + "os" + "strconv" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +var ( + vpnInterfaceName string + vpnInitMutex sync.RWMutex + + advancedRoutingSupported bool +) + +func Init() { + advancedRoutingSupported = checkAdvancedRoutingSupport() +} + +func checkAdvancedRoutingSupport() bool { + var err error + var legacyRouting bool + if val := os.Getenv(envUseLegacyRouting); val != "" { + legacyRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + } + } + + if legacyRouting || netstack.IsEnabled() { + log.Info("advanced routing has been requested to be disabled") + return false + } + + log.Info("system supports advanced routing") + + return true +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes +func AdvancedRouting() bool { + return advancedRoutingSupported +} + +// GetVPNInterfaceName returns the stored VPN interface name +func GetVPNInterfaceName() string { + vpnInitMutex.RLock() + defer vpnInitMutex.RUnlock() + return vpnInterfaceName +} + +// SetVPNInterfaceName sets the VPN interface name for lazy initialization +func SetVPNInterfaceName(name string) { + 
vpnInitMutex.Lock() + defer vpnInitMutex.Unlock() + vpnInterfaceName = name + + if name != "" { + log.Infof("VPN interface name set to %s for route exclusion", name) + } +} diff --git a/client/net/hooks/hooks.go b/client/net/hooks/hooks.go new file mode 100644 index 000000000..93d8e18ef --- /dev/null +++ b/client/net/hooks/hooks.go @@ -0,0 +1,93 @@ +package hooks + +import ( + "net/netip" + "slices" + "sync" + + "github.com/google/uuid" +) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} + +type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error +type CloseHookFunc func(connID ConnectionID) error +type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + +var ( + hooksMutex sync.RWMutex + + writeHooks []WriteHookFunc + closeHooks []CloseHookFunc + addressRemoveHooks []AddressRemoveHookFunc +) + +// AddWriteHook allows adding a new hook to be executed before writing/dialing. +func AddWriteHook(hook WriteHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = append(writeHooks, hook) +} + +// AddCloseHook allows adding a new hook to be executed on connection close. +func AddCloseHook(hook CloseHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = append(closeHooks, hook) +} + +// RemoveWriteHooks removes all write hooks. +func RemoveWriteHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = nil +} + +// RemoveCloseHooks removes all close hooks. +func RemoveCloseHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = nil +} + +// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed. 
+func AddAddressRemoveHook(hook AddressRemoveHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = append(addressRemoveHooks, hook) +} + +// RemoveAddressRemoveHooks removes all listener address hooks. +func RemoveAddressRemoveHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = nil +} + +// GetWriteHooks returns a copy of the current write hooks. +func GetWriteHooks() []WriteHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(writeHooks) +} + +// GetCloseHooks returns a copy of the current close hooks. +func GetCloseHooks() []CloseHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(closeHooks) +} + +// GetAddressRemoveHooks returns a copy of the current listener address remove hooks. +func GetAddressRemoveHooks() []AddressRemoveHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(addressRemoveHooks) +} diff --git a/client/net/listen.go b/client/net/listen.go new file mode 100644 index 000000000..da7262806 --- /dev/null +++ b/client/net/listen.go @@ -0,0 +1,47 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. 
+func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *PacketConn: + // Legacy routing: wrapped connection for hooks + udpConn, ok := c.PacketConn.(*net.UDPConn) + if !ok { + if err := c.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/listen_ios.go b/client/net/listen_ios.go similarity index 100% rename from util/net/listen_ios.go rename to client/net/listen_ios.go diff --git a/util/net/listener.go b/client/net/listener.go similarity index 81% rename from util/net/listener.go rename to client/net/listener.go index f4d769f58..4c2f53c05 100644 --- a/util/net/listener.go +++ b/client/net/listener.go @@ -7,14 +7,12 @@ import ( // ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before // responding via the socket and after closing. This can be used to bypass the VPN for listeners. type ListenerConfig struct { - *net.ListenConfig + net.ListenConfig } // NewListener creates a new ListenerConfig instance. 
func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } + listener := &ListenerConfig{} listener.init() return listener diff --git a/util/net/listener_init_android.go b/client/net/listener_init_android.go similarity index 100% rename from util/net/listener_init_android.go rename to client/net/listener_init_android.go diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go new file mode 100644 index 000000000..4f8f17ab2 --- /dev/null +++ b/client/net/listener_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (l *ListenerConfig) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/listener_init_linux.go b/client/net/listener_init_linux.go similarity index 100% rename from util/net/listener_init_linux.go rename to client/net/listener_init_linux.go diff --git a/client/net/listener_init_windows.go b/client/net/listener_init_windows.go new file mode 100644 index 000000000..a9399b5f1 --- /dev/null +++ b/client/net/listener_init_windows.go @@ -0,0 +1,8 @@ +package net + +func (l *ListenerConfig) init() { + // TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses. + // For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case + // the interface will be selected that serves the default route. 
+ l.ListenConfig.Control = applyUnicastIFToSocket +} diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go new file mode 100644 index 000000000..0bb5ad67d --- /dev/null +++ b/client/net/listener_listen.go @@ -0,0 +1,153 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() || AdvancedRouting() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := hooks.GenerateConnID() + + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. 
+func (c *PacketConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.UDPConn) +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *PacketConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen()) + + addressRemoveHooks := hooks.GetAddressRemoveHooks() + if len(addressRemoveHooks) == 0 { + return + } + + for _, hook := range addressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + +// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality +func WrapPacketConn(conn net.PacketConn) net.PacketConn { + if AdvancedRouting() { + // hooks not required for advanced routing + return conn + } + return &PacketConn{ + PacketConn: conn, + ID: hooks.GenerateConnID(), + seenAddrs: 
&sync.Map{}, + } +} + +func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error { + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded { + return nil + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr) + } + + prefix, err := util.GetPrefixFromIP(udpAddr.IP) + if err != nil { + return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err) + } + + log.Debugf("Listener resolved IP for %s: %s", addr, prefix) + + var merr *multierror.Error + for _, hook := range writeHooks { + if err := hook(id, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/listener_listen_ios.go b/client/net/listener_listen_ios.go similarity index 100% rename from util/net/listener_listen_ios.go rename to client/net/listener_listen_ios.go diff --git a/util/net/net.go b/client/net/net.go similarity index 81% rename from util/net/net.go rename to client/net/net.go index fdcf4ee6a..a97de9d59 100644 --- a/util/net/net.go +++ b/client/net/net.go @@ -5,8 +5,6 @@ import ( "math/big" "net" "net/netip" - - "github.com/google/uuid" ) const ( @@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool { return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper } -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -type AddHookFunc func(connID ConnectionID, IP net.IP) error -type RemoveHookFunc func(connID ConnectionID) error - -// GenerateConnID generates a unique identifier for each connection. 
-func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} - func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) { var endIP net.IP addr := network.Addr().AsSlice() diff --git a/util/net/net_linux.go b/client/net/net_linux.go similarity index 100% rename from util/net/net_linux.go rename to client/net/net_linux.go diff --git a/util/net/net_test.go b/client/net/net_test.go similarity index 100% rename from util/net/net_test.go rename to client/net/net_test.go diff --git a/client/net/net_windows.go b/client/net/net_windows.go new file mode 100644 index 000000000..649d83aaf --- /dev/null +++ b/client/net/net_windows.go @@ -0,0 +1,284 @@ +package net + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "syscall" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options + IpUnicastIf = 31 + Ipv6UnicastIf = 31 + + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options + Ipv6V6only = 27 +) + +// GetBestInterfaceFunc is set at runtime to avoid import cycle +var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error) + +// nativeToBigEndian converts a uint32 from native byte order to big-endian +func nativeToBigEndian(v uint32) uint32 { + return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24 +} + +// parseDestinationAddress parses the destination address from various formats +func parseDestinationAddress(network, address string) (netip.Addr, error) { + if address == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + if addrPort, err := netip.ParseAddrPort(address); err == nil { + return addrPort.Addr(), nil + } + + if dest, err := netip.ParseAddr(address); err == nil { + return dest, nil + } + + host, 
_, err := net.SplitHostPort(address) + if err != nil { + // No port, treat whole string as host + host = address + } + + if host == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil || len(ips) == 0 { + return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err) + } + + dest, ok := netip.AddrFromSlice(ips[0].IP) + if !ok { + return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP) + } + + if ips[0].Zone != "" { + dest = dest.WithZone(ips[0].Zone) + } + + return dest, nil +} + +func getInterfaceFromZone(zone string) *net.Interface { + if zone == "" { + return nil + } + + idx, err := strconv.Atoi(zone) + if err != nil { + log.Debugf("invalid zone format for Windows (expected numeric): %s", zone) + return nil + } + + iface, err := net.InterfaceByIndex(idx) + if err != nil { + log.Debugf("failed to get interface by index %d from zone: %v", idx, err) + return nil + } + + return iface +} + +type interfaceSelection struct { + iface4 *net.Interface + iface6 *net.Interface +} + +func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection { + iface := getInterfaceFromZone(zone) + if iface == nil { + return nil + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface} + } + return &interfaceSelection{iface4: iface} +} + +func selectInterfaceForUnspecified() (*interfaceSelection, error) { + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + var result interfaceSelection + vpnIfaceName := GetVPNInterfaceName() + + if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil { + result.iface4 = iface4 + } else { + log.Debugf("No IPv4 default route found: %v", err) + } + + if iface6, err := 
GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil { + result.iface6 = iface6 + } else { + log.Debugf("No IPv6 default route found: %v", err) + } + + if result.iface4 == nil && result.iface6 == nil { + return nil, errors.New("no default routes found") + } + + return &result, nil +} + +func selectInterface(dest netip.Addr) (*interfaceSelection, error) { + if zone := dest.Zone(); zone != "" { + if selection := selectInterfaceForZone(dest, zone); selection != nil { + return selection, nil + } + } + + if dest.IsUnspecified() { + return selectInterfaceForUnspecified() + } + + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName()) + if err != nil { + return nil, fmt.Errorf("find route for %s: %w", dest, err) + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface}, nil + } + return &interfaceSelection{iface4: iface}, nil +} + +func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error { + ifaceIndexBE := nativeToBigEndian(uint32(iface.Index)) + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil { + return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error { + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil { + return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error { + // The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.) 
+ // Never generic ones (udp, tcp, ip) + + switch { + case strings.HasSuffix(network, "4"): + // IPv4-only socket (udp4, tcp4, ip4) + return setUnicastIfIPv4(fd, network, selection, address) + + case strings.HasSuffix(network, "6"): + // IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only + return setUnicastIfIPv6(fd, network, selection, address) + } + + // Shouldn't reach here based on Go's documented behavior + return fmt.Errorf("unexpected network type: %s", network) +} + +func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error { + if selection.iface4 == nil { + return nil + } + + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address) + return nil +} + +func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error { + isDualStack := checkDualStack(fd) + + // For dual-stack sockets, also set the IPv4 option + if isDualStack && selection.iface4 != nil { + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address) + } + + if selection.iface6 == nil { + return nil + } + + if err := setIPv6UnicastIF(fd, selection.iface6); err != nil { + return err + } + + log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address) + return nil +} + +func checkDualStack(fd uintptr) bool { + var v6Only int + v6OnlyLen := int32(unsafe.Sizeof(v6Only)) + err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen) + return err == nil && v6Only == 0 +} + +// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address +func 
applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + dest, err := parseDestinationAddress(network, address) + if err != nil { + return err + } + + dest = dest.Unmap() + + if !dest.IsValid() { + return fmt.Errorf("invalid destination address for %s", address) + } + + selection, err := selectInterface(dest) + if err != nil { + return err + } + + var controlErr error + err = c.Control(func(fd uintptr) { + controlErr = setUnicastIf(fd, network, selection, address) + }) + + if err != nil { + return fmt.Errorf("control: %w", err) + } + + return controlErr +} diff --git a/util/net/protectsocket_android.go b/client/net/protectsocket_android.go similarity index 89% rename from util/net/protectsocket_android.go rename to client/net/protectsocket_android.go index febed8a1e..00071461d 100644 --- a/util/net/protectsocket_android.go +++ b/client/net/protectsocket_android.go @@ -4,6 +4,8 @@ import ( "fmt" "sync" "syscall" + + "github.com/netbirdio/netbird/client/iface/netstack" ) var ( @@ -19,6 +21,9 @@ func SetAndroidProtectSocketFn(fn func(fd int32) bool) { // ControlProtectSocket is a Control function that sets the fwmark on the socket func ControlProtectSocket(_, _ string, c syscall.RawConn) error { + if netstack.IsEnabled() { + return nil + } var aErr error err := c.Control(func(fd uintptr) { androidProtectSocketLock.Lock() diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 2422d2683..7c9fa021a 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -2,7 +2,7 @@ set -eEuo pipefail : ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"} +: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() @@ -39,7 +39,7 @@ wait_for_message() { info "not waiting for log line ${message@Q} due to zero timeout." 
elif test -n "${log_file_path}"; then info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) + grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) else info "log file unsupported, sleeping for ${timeout} seconds..." sleep "${timeout}" @@ -81,7 +81,7 @@ wait_for_daemon_startup() { login_if_needed() { local timeout="${1}" - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then + if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then info "already logged in, skipping 'netbird up'..." else info "logging in..." diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 60835d1cd..841e3c0f7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.36.6 -// protoc v5.29.3 +// protoc v6.32.1 // source: daemon.proto package proto @@ -278,6 +278,7 @@ type LoginRequest struct { BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"` ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` + Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -530,6 +531,13 @@ func (x *LoginRequest) GetUsername() string { return "" } +func (x *LoginRequest) GetMtu() int64 { + if x != nil && x.Mtu != nil { + return *x.Mtu + } + return 0 +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -786,8 +794,10 @@ type StatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"` ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *StatusRequest) Reset() { @@ -834,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool { return false } +func (x *StatusRequest) GetWaitForReady() bool { + if x != nil && x.WaitForReady != nil { + return *x.WaitForReady + } + return false +} + type StatusResponse struct { state protoimpl.MessageState 
`protogen:"open.v1"` // status of the server. @@ -1034,6 +1051,7 @@ type GetConfigResponse struct { AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` + Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` @@ -1129,6 +1147,13 @@ func (x *GetConfigResponse) GetWireguardPort() int64 { return 0 } +func (x *GetConfigResponse) GetMtu() int64 { + if x != nil { + return x.Mtu + } + return 0 +} + func (x *GetConfigResponse) GetDisableAutoConnect() bool { if x != nil { return x.DisableAutoConnect @@ -3679,6 +3704,7 @@ type SetConfigRequest struct { // cleanDNSLabels clean map list of DNS labels. 
CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` + Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -3902,6 +3928,13 @@ func (x *SetConfigRequest) GetDnsRouteInterval() *durationpb.Duration { return nil } +func (x *SetConfigRequest) GetMtu() int64 { + if x != nil && x.Mtu != nil { + return *x.Mtu + } + return 0 +} + type SetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -4575,7 +4608,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xa4\x0e\n" + + "\fEmptyRequest\"\xc3\x0e\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4611,7 +4644,8 @@ const file_daemon_proto_rawDesc = "" + "\x15lazyConnectionEnabled\x18\x1c \x01(\bH\x0fR\x15lazyConnectionEnabled\x88\x01\x01\x12(\n" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01B\x13\n" + + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4630,7 +4664,8 @@ const file_daemon_proto_rawDesc = "" + "\x16_lazyConnectionEnabledB\x10\n" + "\x0e_block_inboundB\x0e\n" + "\f_profileNameB\v\n" + - "\t_username\"\xb5\x01\n" 
+ + "\t_usernameB\x06\n" + + "\x04_mtu\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -4647,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\f\n" + "\n" + - "UpResponse\"g\n" + + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + "\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" + - "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" + + "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" + + "\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" + + "\r_waitForReady\"\x82\x01\n" + "\x0eStatusResponse\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status\x122\n" + "\n" + @@ -4661,7 +4698,7 @@ const file_daemon_proto_rawDesc = "" + "\fDownResponse\"P\n" + "\x10GetConfigRequest\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername\"\xa3\x06\n" + + "\busername\x18\x02 \x01(\tR\busername\"\xb5\x06\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -4671,7 +4708,8 @@ const file_daemon_proto_rawDesc = "" + "\fpreSharedKey\x18\x04 \x01(\tR\fpreSharedKey\x12\x1a\n" + "\badminURL\x18\x05 \x01(\tR\badminURL\x12$\n" + "\rinterfaceName\x18\x06 \x01(\tR\rinterfaceName\x12$\n" + - "\rwireguardPort\x18\a \x01(\x03R\rwireguardPort\x12.\n" + + "\rwireguardPort\x18\a \x01(\x03R\rwireguardPort\x12\x10\n" + + "\x03mtu\x18\b \x01(\x03R\x03mtu\x12.\n" + "\x12disableAutoConnect\x18\t \x01(\bR\x12disableAutoConnect\x12*\n" + "\x10serverSSHAllowed\x18\n" + " \x01(\bR\x10serverSSHAllowed\x12*\n" + @@ -4885,7 +4923,7 @@ const file_daemon_proto_rawDesc = "" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + "\t_username\"\x17\n" + - "\x15SwitchProfileResponse\"\xef\f\n" + + "\x15SwitchProfileResponse\"\x8e\r\n" + 
"\x10SetConfigRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + @@ -4917,7 +4955,8 @@ const file_daemon_proto_rawDesc = "" + "\n" + "dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" + "\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" + - "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01B\x13\n" + + "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" + + "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4934,7 +4973,8 @@ const file_daemon_proto_rawDesc = "" + "\x16_disable_notificationsB\x18\n" + "\x16_lazyConnectionEnabledB\x10\n" + "\x0e_block_inboundB\x13\n" + - "\x11_dnsRouteInterval\"\x13\n" + + "\x11_dnsRouteIntervalB\x06\n" + + "\x04_mtu\"\x13\n" + "\x11SetConfigResponse\"Q\n" + "\x11AddProfileRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + @@ -5202,6 +5242,7 @@ func file_daemon_proto_init() { } file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[7].OneofWrappers = []any{} file_daemon_proto_msgTypes[26].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index fa54071ec..5b27b4d98 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -156,6 +156,8 @@ message LoginRequest { optional string profileName = 30; optional string username = 31; + + optional int64 mtu = 32; } message LoginResponse { @@ -184,6 +186,8 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; bool shouldRunProbes = 2; + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + optional bool waitForReady = 3; } message 
StatusResponse{ @@ -223,6 +227,8 @@ message GetConfigResponse { int64 wireguardPort = 7; + int64 mtu = 8; + bool disableAutoConnect = 9; bool serverSSHAllowed = 10; @@ -538,36 +544,36 @@ message SetConfigRequest { string profileName = 2; // managementUrl to authenticate. string managementUrl = 3; - + // adminUrl to manage keys. string adminURL = 4; - + optional bool rosenpassEnabled = 5; - + optional string interfaceName = 6; - + optional int64 wireguardPort = 7; - + optional string optionalPreSharedKey = 8; - + optional bool disableAutoConnect = 9; - + optional bool serverSSHAllowed = 10; - + optional bool rosenpassPermissive = 11; - + optional bool networkMonitor = 12; - + optional bool disable_client_routes = 13; optional bool disable_server_routes = 14; optional bool disable_dns = 15; optional bool disable_firewall = 16; optional bool block_lan_access = 17; - + optional bool disable_notifications = 18; - + optional bool lazyConnectionEnabled = 19; - + optional bool block_inbound = 20; repeated string natExternalIPs = 21; @@ -583,6 +589,7 @@ message SetConfigRequest { optional google.protobuf.Duration dnsRouteInterval = 27; + optional int64 mtu = 28; } message SetConfigResponse{} @@ -633,4 +640,4 @@ message GetFeaturesRequest{} message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; -} \ No newline at end of file +} diff --git a/client/server/server.go b/client/server/server.go index dd842d099..168b297c6 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -63,6 +63,9 @@ type Server struct { mutex sync.Mutex config *profilemanager.Config proto.UnimplementedDaemonServiceServer + clientRunning bool // protected by mutex + clientRunningChan chan struct{} + clientGiveUpChan chan struct{} connectClient *internal.ConnectClient @@ -101,6 +104,11 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() + + if 
s.clientRunning { + return nil + } + state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -170,8 +178,10 @@ func (s *Server) Start() error { return nil } - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) - + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } @@ -202,12 +212,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - runningChan chan struct{}, -) { - backOff := getConnectWithBackoff(ctx) - retryStarted := false +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { + defer func() { + s.mutex.Lock() + s.clientRunning = false + s.mutex.Unlock() + }() + if s.config.DisableAutoConnect { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + log.Debugf("run client connection exited with error: %v", err) + } + log.Tracef("client connection exited") + return + } + + backOff := getConnectWithBackoff(ctx) go func() { t := time.NewTicker(24 * time.Hour) for { @@ -216,89 +236,36 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage t.Stop() return case <-t.C: - if retryStarted { - - mgmtState := statusRecorder.GetManagementState() - signalState := statusRecorder.GetSignalState() - if mgmtState.Connected && signalState.Connected { 
- log.Tracef("resetting status") - retryStarted = false - } else { - log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) - } + mgmtState := statusRecorder.GetManagementState() + signalState := statusRecorder.GetSignalState() + if mgmtState.Connected && signalState.Connected { + log.Tracef("resetting status") + backOff.Reset() + } else { + log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) } } } }() runOperation := func() error { - log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, s.logFile) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - - err := s.connectClient.Run(runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) + return err } - if config.DisableAutoConnect { - return backoff.Permanent(err) - } - - if !retryStarted { - retryStarted = true - backOff.Reset() - } - - log.Tracef("client connection exited") - return fmt.Errorf("client connection exited") + log.Tracef("client connection exited gracefully, do not need to retry") + return nil } - err := backoff.Retry(runOperation, backOff) - if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { - log.Errorf("received an error when trying to connect: %v", err) - } else { - log.Tracef("retry canceled") - } -} - -// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries -func getConnectWithBackoff(ctx context.Context) backoff.BackOff { - initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) - maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) - maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) - multiplier := defaultRetryMultiplier - - if envValue := 
os.Getenv(retryMultiplierVar); envValue != "" { - // parse the multiplier from the environment variable string value to float64 - value, err := strconv.ParseFloat(envValue, 64) - if err != nil { - log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) - } else { - multiplier = value - } + if err := backoff.Retry(runOperation, backOff); err != nil { + log.Errorf("operation failed: %v", err) } - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: initialInterval, - RandomizationFactor: 1, - Multiplier: multiplier, - MaxInterval: maxInterval, - MaxElapsedTime: maxElapsedTime, // 14 days - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - -// parseEnvDuration parses the environment variable and returns the duration -func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { - if envValue := os.Getenv(envVar); envValue != "" { - if duration, err := time.ParseDuration(envValue); err == nil { - return duration - } - log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + if giveUpChan != nil { + close(giveUpChan) } - return defaultDuration } // loginAttempt attempts to login using the provided information. 
it returns a status in case something fails @@ -398,6 +365,11 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound + if msg.Mtu != nil { + mtu := uint16(*msg.Mtu) + config.MTU = &mtu + } + if _, err := profilemanager.UpdateConfig(config); err != nil { log.Errorf("failed to update profile config: %v", err) return nil, fmt.Errorf("failed to update profile config: %w", err) @@ -412,7 +384,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if s.actCancel != nil { s.actCancel() } - ctx, cancel := context.WithCancel(s.rootCtx) + ctx, cancel := context.WithCancel(callerCtx) md, ok := metadata.FromIncomingContext(callerCtx) if ok { @@ -422,11 +394,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() - if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { + if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } - state := internal.CtxGetState(ctx) + state := internal.CtxGetState(s.rootCtx) defer func() { status, err := state.Status() if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) { @@ -482,6 +454,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro // nolint ctx = context.WithValue(ctx, system.DeviceNameCtxKey, msg.Hostname) } + s.mutex.Unlock() config, err := s.getConfig(activeProf) @@ -638,6 +611,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin // Up starts engine work in the daemon. 
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { s.mutex.Lock() + if s.clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return nil, err + } + if status == internal.StatusNeedsLogin { + s.actCancel() + } + s.mutex.Unlock() + + return s.waitForUp(callerCtx) + } defer s.mutex.Unlock() if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { @@ -653,16 +640,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR if err != nil { return nil, err } + if status != internal.StatusIdle { return nil, fmt.Errorf("up already in progress: current status %s", status) } - // it should be nil here, but . + // it should be nil here, but in case it isn't we cancel it. if s.actCancel != nil { s.actCancel() } ctx, cancel := context.WithCancel(s.rootCtx) - md, ok := metadata.FromIncomingContext(callerCtx) if ok { ctx = metadata.NewOutgoingContext(ctx, md) @@ -705,23 +692,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + return s.waitForUp(callerCtx) +} + +// todo: handle potential race conditions +func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) { timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) - for { - select { - case <-runningChan: - 
s.isSessionActive.Store(true) - return &proto.UpResponse{}, nil - case <-callerCtx.Done(): - log.Debug("context done, stopping the wait for engine to become ready") - return nil, callerCtx.Err() - case <-timeoutCtx.Done(): - log.Debug("up is timed out, stopping the wait for engine to become ready") - return nil, timeoutCtx.Err() - } + select { + case <-s.clientGiveUpChan: + return nil, fmt.Errorf("client gave up to connect") + case <-s.clientRunningChan: + s.isSessionActive.Store(true) + return &proto.UpResponse{}, nil + case <-callerCtx.Done(): + log.Debug("context done, stopping the wait for engine to become ready") + return nil, callerCtx.Err() + case <-timeoutCtx.Done(): + log.Debug("up is timed out, stopping the wait for engine to become ready") + return nil, timeoutCtx.Err() } } @@ -995,12 +990,46 @@ func (s *Server) Status( ctx context.Context, msg *proto.StatusRequest, ) (*proto.StatusResponse, error) { - if ctx.Err() != nil { - return nil, ctx.Err() - } - s.mutex.Lock() - defer s.mutex.Unlock() + clientRunning := s.clientRunning + s.mutex.Unlock() + + if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + return nil, err + } + + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + loop: + for { + select { + case <-s.clientGiveUpChan: + ticker.Stop() + break loop + case <-s.clientRunningChan: + ticker.Stop() + break loop + case <-ticker.C: + status, err := state.Status() + if err != nil { + continue + } + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + continue + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } status, err := internal.CtxGetState(s.rootCtx).Status() if err != nil { @@ -1103,6 +1132,7 @@ func (s 
*Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p AdminURL: adminURL.String(), InterfaceName: cfg.WgIface, WireguardPort: int64(cfg.WgPort), + Mtu: int64(cfg.MTU), DisableAutoConnect: cfg.DisableAutoConnect, ServerSSHAllowed: *cfg.ServerSSHAllowed, RosenpassEnabled: cfg.RosenpassEnabled, @@ -1118,45 +1148,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p }, nil } -func (s *Server) onSessionExpire() { - if runtime.GOOS != "windows" { - isUIActive := internal.CheckUIApp() - if !isUIActive && s.config.DisableNotifications != nil && !*s.config.DisableNotifications { - if err := sendTerminalNotification(); err != nil { - log.Errorf("send session expire terminal notification: %v", err) - } - } - } -} - -// sendTerminalNotification sends a terminal notification message -// to inform the user that the NetBird connection session has expired. -func sendTerminalNotification() error { - message := "NetBird connection session expired\n\nPlease re-authenticate to connect to the network." - echoCmd := exec.Command("echo", message) - wallCmd := exec.Command("sudo", "wall") - - echoCmdStdout, err := echoCmd.StdoutPipe() - if err != nil { - return err - } - wallCmd.Stdin = echoCmdStdout - - if err := echoCmd.Start(); err != nil { - return err - } - - if err := wallCmd.Start(); err != nil { - return err - } - - if err := echoCmd.Wait(); err != nil { - return err - } - - return wallCmd.Wait() -} - // AddProfile adds a new profile to the daemon. 
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { s.mutex.Lock() @@ -1257,6 +1248,16 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) return features, nil } +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { + log.Tracef("running client connection") + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) + if err := s.connectClient.Run(runningChan); err != nil { + return err + } + return nil +} + func (s *Server) checkProfilesDisabled() bool { // Check if the environment variable is set to disable profiles if s.profilesDisabled { @@ -1274,3 +1275,168 @@ func (s *Server) checkUpdateSettingsDisabled() bool { return false } + +func (s *Server) onSessionExpire() { + if runtime.GOOS != "windows" { + isUIActive := internal.CheckUIApp() + if !isUIActive && s.config.DisableNotifications != nil && !*s.config.DisableNotifications { + if err := sendTerminalNotification(); err != nil { + log.Errorf("send session expire terminal notification: %v", err) + } + } + } +} + +// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries +func getConnectWithBackoff(ctx context.Context) backoff.BackOff { + initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) + maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) + maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) + multiplier := defaultRetryMultiplier + + if envValue := os.Getenv(retryMultiplierVar); envValue != "" { + // parse the multiplier from the environment variable string value to float64 + value, err := strconv.ParseFloat(envValue, 64) + if err != nil { + log.Warnf("unable to parse environment variable %s: %s. 
using default: %f", retryMultiplierVar, envValue, multiplier) + } else { + multiplier = value + } + } + + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: initialInterval, + RandomizationFactor: 1, + Multiplier: multiplier, + MaxInterval: maxInterval, + MaxElapsedTime: maxElapsedTime, // 14 days + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// parseEnvDuration parses the environment variable and returns the duration +func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { + if envValue := os.Getenv(envVar); envValue != "" { + if duration, err := time.ParseDuration(envValue); err == nil { + return duration + } + log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + } + return defaultDuration +} + +func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { + pbFullStatus := proto.FullStatus{ + ManagementState: &proto.ManagementState{}, + SignalState: &proto.SignalState{}, + LocalPeerState: &proto.LocalPeerState{}, + Peers: []*proto.PeerState{}, + } + + pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL + pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected + if err := fullStatus.ManagementState.Error; err != nil { + pbFullStatus.ManagementState.Error = err.Error() + } + + pbFullStatus.SignalState.URL = fullStatus.SignalState.URL + pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected + if err := fullStatus.SignalState.Error; err != nil { + pbFullStatus.SignalState.Error = err.Error() + } + + pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP + pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey + pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface + pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN + pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive + 
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled + pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) + pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules) + pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled + + for _, peerState := range fullStatus.Peers { + pbPeerState := &proto.PeerState{ + IP: peerState.IP, + PubKey: peerState.PubKey, + ConnStatus: peerState.ConnStatus.String(), + ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), + Relayed: peerState.Relayed, + LocalIceCandidateType: peerState.LocalIceCandidateType, + RemoteIceCandidateType: peerState.RemoteIceCandidateType, + LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint, + RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint, + RelayAddress: peerState.RelayServerAddress, + Fqdn: peerState.FQDN, + LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake), + BytesRx: peerState.BytesRx, + BytesTx: peerState.BytesTx, + RosenpassEnabled: peerState.RosenpassEnabled, + Networks: maps.Keys(peerState.GetRoutes()), + Latency: durationpb.New(peerState.Latency), + } + pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) + } + + for _, relayState := range fullStatus.Relays { + pbRelayState := &proto.RelayState{ + URI: relayState.URI, + Available: relayState.Err == nil, + } + if err := relayState.Err; err != nil { + pbRelayState.Error = err.Error() + } + pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState) + } + + for _, dnsState := range fullStatus.NSGroupStates { + var err string + if dnsState.Error != nil { + err = dnsState.Error.Error() + } + + var servers []string + for _, server := range dnsState.Servers { + servers = append(servers, server.String()) + } + + pbDnsState := &proto.NSGroupState{ + Servers: servers, + Domains: dnsState.Domains, + Enabled: dnsState.Enabled, + Error: err, + } + pbFullStatus.DnsServers = 
append(pbFullStatus.DnsServers, pbDnsState) + } + + return &pbFullStatus +} + +// sendTerminalNotification sends a terminal notification message +// to inform the user that the NetBird connection session has expired. +func sendTerminalNotification() error { + message := "NetBird connection session expired\n\nPlease re-authenticate to connect to the network." + echoCmd := exec.Command("echo", message) + wallCmd := exec.Command("sudo", "wall") + + echoCmdStdout, err := echoCmd.StdoutPipe() + if err != nil { + return err + } + wallCmd.Stdin = echoCmdStdout + + if err := echoCmd.Start(); err != nil { + return err + } + + if err := wallCmd.Start(); err != nil { + return err + } + + if err := echoCmd.Wait(); err != nil { + return err + } + + return wallCmd.Wait() +} diff --git a/client/server/server_test.go b/client/server/server_test.go index 6f7c4a89a..09b0ed499 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,25 +10,25 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/groups" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" 
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) { profName := "default" + u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") + require.NoError(t, err) + ic := profilemanager.ConfigInput{ - ConfigPath: filepath.Join(tempDir, profName+".json"), + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: u.String(), } _, err = profilemanager.UpdateOrCreateConfig(ic) @@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) { } s := New(ctx, "console", "", false, false) - err = s.Start() require.NoError(t, err) - u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") - require.NoError(t, err) - s.config = &profilemanager.Config{ - ManagementURL: u, - } - upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) { Username: &currUser.Username, } _, err = s.Up(upCtx, upReq) + log.Errorf("error from Up: %v", err) assert.Contains(t, err.Error(), "context deadline exceeded") } @@ -295,15 +293,20 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + permissionsManagerMock := permissions.NewMockManager(ctrl) + 
peersManager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - permissionsManagerMock := permissions.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) diff --git a/client/system/info.go b/client/system/info.go index ea3f6063a..a180be4c0 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -6,6 +6,7 @@ import ( "net/netip" "strings" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/metadata" "github.com/netbirdio/netbird/shared/management/proto" @@ -95,14 +96,6 @@ func (i *Info) SetFlags( i.LazyConnectionEnabled = lazyConnectionEnabled } -// StaticInfo is an object that contains machine information that does not change -type StaticInfo struct { - SystemSerialNumber string - SystemProductName string - SystemManufacturer string - Environment Environment -} - // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) @@ -180,6 +173,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { // GetInfoWithChecks retrieves and parses the system information with applied checks. 
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { + log.Debugf("gathering system information with checks: %d", len(checks)) processCheckPaths := make([]string, 0) for _, check := range checks { processCheckPaths = append(processCheckPaths, check.GetFiles()...) @@ -189,16 +183,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro if err != nil { return nil, err } + log.Debugf("gathering process check information completed") info := GetInfo(ctx) info.Files = files + log.Debugf("all system information gathered successfully") return info, nil } - -// UpdateStaticInfo asynchronously updates static system and platform information -func UpdateStaticInfo() { - go func() { - _ = updateStaticInfo() - }() -} diff --git a/client/system/info_android.go b/client/system/info_android.go index 56fe0741d..78895bfa8 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -15,6 +15,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { kernel := "android" diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index f105ada60..caa344737 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -19,6 +19,10 @@ import ( "github.com/netbirdio/netbird/version" ) +func UpdateStaticInfoAsync() { + go updateStaticInfo() +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { utsname := unix.Utsname{} @@ -41,7 +45,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_freebsd.go 
b/client/system/info_freebsd.go index bed6711de..8e1353151 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -18,6 +18,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { out := _getInfo() diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 897ec0a35..705c37920 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -10,6 +10,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { diff --git a/client/system/info_linux.go b/client/system/info_linux.go index 9bfc82009..6c7a23b95 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -23,6 +23,10 @@ var ( getSystemInfo = defaultSysInfoImplementation ) +func UpdateStaticInfoAsync() { + go updateStaticInfo() +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { info := _getInfo() @@ -48,7 +52,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_windows.go b/client/system/info_windows.go index 6f05ded20..d7f8f30aa 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -2,187 +2,51 @@ package system import ( "context" - "fmt" "os" "runtime" - "strings" "time" log "github.com/sirupsen/logrus" - "github.com/yusufpapurcu/wmi" - "golang.org/x/sys/windows/registry" "github.com/netbirdio/netbird/version" 
) -type Win32_OperatingSystem struct { - Caption string -} - -type Win32_ComputerSystem struct { - Manufacturer string -} - -type Win32_ComputerSystemProduct struct { - Name string -} - -type Win32_BIOS struct { - SerialNumber string +func UpdateStaticInfoAsync() { + go updateStaticInfo() } // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - osName, osVersion := getOSNameAndVersion() - buildVersion := getBuildVersion() - - addrs, err := networkAddresses() - if err != nil { - log.Warnf("failed to discover network addresses: %s", err) - } - start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ Kernel: "windows", - OSVersion: osVersion, + OSVersion: si.OSVersion, Platform: "unknown", - OS: osName, + OS: si.OSName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), - KernelVersion: buildVersion, - NetworkAddresses: addrs, + KernelVersion: si.BuildVersion, SystemSerialNumber: si.SystemSerialNumber, SystemProductName: si.SystemProductName, SystemManufacturer: si.SystemManufacturer, Environment: si.Environment, } + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } else { + gio.NetworkAddresses = addrs + } + systemHostname, _ := os.Hostname() gio.Hostname = extractDeviceName(ctx, systemHostname) gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) - return gio } - -func sysInfo() (serialNumber string, productName string, manufacturer string) { - var err error - serialNumber, err = sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - productName, err = sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err = sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - 
- return serialNumber, productName, manufacturer -} - -func getOSNameAndVersion() (string, string) { - var dst []Win32_OperatingSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - log.Error(err) - return "Windows", getBuildVersion() - } - - if len(dst) == 0 { - return "Windows", getBuildVersion() - } - - split := strings.Split(dst[0].Caption, " ") - - if len(split) <= 3 { - return "Windows", getBuildVersion() - } - - name := split[1] - version := split[2] - if split[2] == "Server" { - name = fmt.Sprintf("%s %s", split[1], split[2]) - version = split[3] - } - - return name, version -} - -func getBuildVersion() string { - k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) - if err != nil { - log.Error(err) - return "0.0.0.0" - } - defer func() { - deferErr := k.Close() - if deferErr != nil { - log.Error(deferErr) - } - }() - - major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") - if err != nil { - log.Error(err) - } - minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") - if err != nil { - log.Error(err) - } - build, _, err := k.GetStringValue("CurrentBuildNumber") - if err != nil { - log.Error(err) - } - // Update Build Revision - ubr, _, err := k.GetIntegerValue("UBR") - if err != nil { - log.Error(err) - } - ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) - return ver -} - -func sysNumber() (string, error) { - var dst []Win32_BIOS - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - return dst[0].SerialNumber, nil -} - -func sysProductName() (string, error) { - var dst []Win32_ComputerSystemProduct - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - // `ComputerSystemProduct` could be empty on some virtualized systems - if len(dst) < 1 { - return "unknown", nil - } - return dst[0].Name, nil -} - -func 
sysManufacturer() (string, error) { - var dst []Win32_ComputerSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - return dst[0].Manufacturer, nil -} diff --git a/client/system/static_info.go b/client/system/static_info.go index f178ec932..12a2663a1 100644 --- a/client/system/static_info.go +++ b/client/system/static_info.go @@ -3,12 +3,7 @@ package system import ( - "context" "sync" - "time" - - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" ) var ( @@ -16,25 +11,26 @@ var ( once sync.Once ) -func updateStaticInfo() StaticInfo { +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment + + // Windows specific fields + OSName string + OSVersion string + BuildVersion string +} + +func updateStaticInfo() { once.Do(func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - wg := sync.WaitGroup{} - wg.Add(3) - go func() { - staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() - wg.Done() - }() - go func() { - staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) - wg.Done() - }() - go func() { - staticInfo.Environment.Platform = detect_platform.Detect(ctx) - wg.Done() - }() - wg.Wait() + staticInfo = newStaticInfo() }) +} + +func getStaticInfo() StaticInfo { + updateStaticInfo() return staticInfo } diff --git a/client/system/static_info_stub.go b/client/system/static_info_stub.go deleted file mode 100644 index faa3e700b..000000000 --- a/client/system/static_info_stub.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build android || freebsd || ios - -package system - -// updateStaticInfo returns an empty implementation for unsupported platforms -func updateStaticInfo() StaticInfo { - return 
StaticInfo{} -} diff --git a/client/system/static_info_update.go b/client/system/static_info_update.go new file mode 100644 index 000000000..af8b1e266 --- /dev/null +++ b/client/system/static_info_update.go @@ -0,0 +1,35 @@ +//go:build (linux && !android) || (darwin && !ios) + +package system + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +func newStaticInfo() StaticInfo { + si := StaticInfo{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + si.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + si.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() + return si +} diff --git a/client/system/static_info_update_windows.go b/client/system/static_info_update_windows.go new file mode 100644 index 000000000..5f232c1de --- /dev/null +++ b/client/system/static_info_update_windows.go @@ -0,0 +1,184 @@ +package system + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "github.com/yusufpapurcu/wmi" + "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +type Win32_OperatingSystem struct { + Caption string +} + +type Win32_ComputerSystem struct { + Manufacturer string +} + +type Win32_ComputerSystemProduct struct { + Name string +} + +type Win32_BIOS struct { + SerialNumber string +} + +func newStaticInfo() StaticInfo { + si := StaticInfo{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + si.SystemSerialNumber, si.SystemProductName, 
si.SystemManufacturer = sysInfo() + wg.Done() + }() + wg.Add(1) + go func() { + si.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + wg.Add(1) + go func() { + si.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Add(1) + go func() { + si.OSName, si.OSVersion = getOSNameAndVersion() + wg.Done() + }() + wg.Add(1) + go func() { + si.BuildVersion = getBuildVersion() + wg.Done() + }() + wg.Wait() + return si +} + +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + +func sysNumber() (string, error) { + var dst []Win32_BIOS + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].SerialNumber, nil +} + +func sysProductName() (string, error) { + var dst []Win32_ComputerSystemProduct + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + // `ComputerSystemProduct` could be empty on some virtualized systems + if len(dst) < 1 { + return "unknown", nil + } + return dst[0].Name, nil +} + +func sysManufacturer() (string, error) { + var dst []Win32_ComputerSystem + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].Manufacturer, nil +} + +func getOSNameAndVersion() (string, string) { + var dst []Win32_OperatingSystem + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + log.Error(err) + return "Windows", getBuildVersion() + } + + if len(dst) == 0 { + return "Windows", 
getBuildVersion() + } + + split := strings.Split(dst[0].Caption, " ") + + if len(split) <= 3 { + return "Windows", getBuildVersion() + } + + name := split[1] + version := split[2] + if split[2] == "Server" { + name = fmt.Sprintf("%s %s", split[1], split[2]) + version = split[3] + } + + return name, version +} + +func getBuildVersion() string { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) + if err != nil { + log.Error(err) + return "0.0.0.0" + } + defer func() { + deferErr := k.Close() + if deferErr != nil { + log.Error(deferErr) + } + }() + + major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") + if err != nil { + log.Error(err) + } + minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") + if err != nil { + log.Error(err) + } + build, _, err := k.GetStringValue("CurrentBuildNumber") + if err != nil { + log.Error(err) + } + // Update Build Revision + ubr, _, err := k.GetIntegerValue("UBR") + if err != nil { + log.Error(err) + } + ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) + return ver +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index f43606de1..25d7380a9 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -257,6 +257,7 @@ type serviceClient struct { iPreSharedKey *widget.Entry iInterfaceName *widget.Entry iInterfacePort *widget.Entry + iMTU *widget.Entry // switch elements for settings form sRosenpassPermissive *widget.Check @@ -272,6 +273,7 @@ type serviceClient struct { RosenpassPermissive bool interfaceName string interfacePort int + mtu uint16 networkMonitor bool disableDNS bool disableClientRoutes bool @@ -413,6 +415,7 @@ func (s *serviceClient) showSettingsUI() { s.iPreSharedKey = widget.NewPasswordEntry() s.iInterfaceName = widget.NewEntry() s.iInterfacePort = widget.NewEntry() + s.iMTU = widget.NewEntry() s.sRosenpassPermissive = widget.NewCheck("Enable Rosenpass permissive mode", nil) @@ -446,6 +449,7 @@ 
func (s *serviceClient) getSettingsForm() *widget.Form { {Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive}, {Text: "Interface Name", Widget: s.iInterfaceName}, {Text: "Interface Port", Widget: s.iInterfacePort}, + {Text: "MTU", Widget: s.iMTU}, {Text: "Management URL", Widget: s.iMngURL}, {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Log File", Widget: s.iLogFile}, @@ -482,6 +486,21 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } + var mtu int64 + mtuText := strings.TrimSpace(s.iMTU.Text) + if mtuText != "" { + var err error + mtu, err = strconv.ParseInt(mtuText, 10, 64) + if err != nil { + dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings) + return + } + if mtu < iface.MinMTU || mtu > iface.MaxMTU { + dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings) + return + } + } + iMngURL := strings.TrimSpace(s.iMngURL.Text) defer s.wSettings.Close() @@ -490,6 +509,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form { if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || s.RosenpassPermissive != s.sRosenpassPermissive.Checked || s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) || + s.mtu != uint16(mtu) || s.networkMonitor != s.sNetworkMonitor.Checked || s.disableDNS != s.sDisableDNS.Checked || s.disableClientRoutes != s.sDisableClientRoutes.Checked || @@ -498,6 +518,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form { s.managementURL = iMngURL s.preSharedKey = s.iPreSharedKey.Text + s.mtu = uint16(mtu) currUser, err := user.Current() if err != nil { @@ -516,6 +537,9 @@ func (s *serviceClient) getSettingsForm() *widget.Form { req.RosenpassPermissive = &s.sRosenpassPermissive.Checked req.InterfaceName = &s.iInterfaceName.Text req.WireguardPort = &port + if mtu > 0 { + req.Mtu = &mtu + } req.NetworkMonitor = &s.sNetworkMonitor.Checked req.DisableDns = &s.sDisableDNS.Checked req.DisableClientRoutes = 
&s.sDisableClientRoutes.Checked @@ -539,27 +563,28 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) - if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) - return - } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("down service: %v", err) - } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + log.Errorf("get service status: %v", err) + dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) return } - } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + return + } + } + }() } }, OnCancel: func() { @@ -1088,6 +1113,7 @@ func (s *serviceClient) getSrvConfig() { s.RosenpassPermissive = cfg.RosenpassPermissive s.interfaceName = cfg.WgIface s.interfacePort = cfg.WgPort + s.mtu = cfg.MTU s.networkMonitor = *cfg.NetworkMonitor s.disableDNS = cfg.DisableDNS @@ -1100,6 +1126,12 @@ func (s *serviceClient) getSrvConfig() { s.iPreSharedKey.SetText(cfg.PreSharedKey) s.iInterfaceName.SetText(cfg.WgIface) s.iInterfacePort.SetText(strconv.Itoa(cfg.WgPort)) + if cfg.MTU != 0 { + s.iMTU.SetText(strconv.Itoa(int(cfg.MTU))) + } else { + s.iMTU.SetText("") + s.iMTU.SetPlaceHolder(strconv.Itoa(int(iface.DefaultMTU))) + } s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive) 
if !cfg.RosenpassEnabled { s.sRosenpassPermissive.Disable() @@ -1160,6 +1192,12 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.WgPort = iface.DefaultWgPort } + if cfg.Mtu != 0 { + config.MTU = uint16(cfg.Mtu) + } else { + config.MTU = iface.DefaultMTU + } + config.DisableAutoConnect = cfg.DisableAutoConnect config.ServerSSHAllowed = &cfg.ServerSSHAllowed config.RosenpassEnabled = cfg.RosenpassEnabled diff --git a/flow/client/client.go b/flow/client/client.go index 949824065..603fd6882 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -20,9 +20,9 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) type GRPCClient struct { diff --git a/go.mod b/go.mod index 771922063..c880ace3e 100644 --- a/go.mod +++ b/go.mod @@ -6,26 +6,24 @@ require ( cunicu.li/go-rosenpass v0.4.0 github.com/cenkalti/backoff/v4 v4.3.0 github.com/cloudflare/circl v1.3.3 // indirect - github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.0 github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.6 - github.com/pion/ice/v3 v3.0.2 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.37.0 - golang.org/x/sys v0.32.0 + golang.org/x/crypto v0.40.0 + golang.org/x/sys v0.34.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.64.1 - google.golang.org/protobuf v1.36.6 + google.golang.org/grpc v1.73.0 + 
google.golang.org/protobuf v1.36.8 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -48,6 +46,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.8 github.com/godbus/dbus/v5 v5.1.0 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/gopacket v1.1.19 @@ -63,17 +62,19 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e + github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 - github.com/pion/logging v0.2.2 + github.com/pion/ice/v4 v4.0.0-00010101000000-000000000000 + github.com/pion/logging v0.2.4 github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 - github.com/pion/transport/v3 v3.0.1 + github.com/pion/stun/v3 v3.0.0 + github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 github.com/quic-go/quic-go v0.48.2 @@ -94,18 +95,18 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 - go.opentelemetry.io/otel v1.26.0 + go.opentelemetry.io/otel v1.35.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 - go.opentelemetry.io/otel/metric v1.26.0 - go.opentelemetry.io/otel/sdk/metric v1.26.0 + go.opentelemetry.io/otel/metric v1.35.0 + go.opentelemetry.io/otel/sdk/metric v1.35.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.39.0 - golang.org/x/oauth2 v0.27.0 - golang.org/x/sync v0.13.0 - golang.org/x/term v0.31.0 + golang.org/x/net v0.42.0 + golang.org/x/oauth2 v0.28.0 + golang.org/x/sync v0.16.0 + golang.org/x/term v0.33.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -118,7 +119,7 @@ require ( require ( cloud.google.com/go/auth v0.3.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.3.0 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect @@ -214,8 +215,10 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect - github.com/pion/mdns v0.0.12 // indirect + github.com/pion/dtls/v3 v3.0.7 // indirect + github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/transport/v2 v2.2.4 // indirect + github.com/pion/turn/v4 v4.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect @@ -232,22 +235,23 @@ require ( github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + github.com/wlynxg/anet v0.0.3 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect - go.opentelemetry.io/otel/sdk v1.26.0 // indirect - go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.opentelemetry.io/otel/sdk v1.35.0 // indirect + 
go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/mod v0.25.0 // indirect + golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) @@ -260,6 +264,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 -replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e +replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index b70a6b84c..1b6cdd0a9 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= 
-cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= @@ -250,8 +250,8 @@ github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -506,10 +506,10 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= 
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= -github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= -github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e h1:S85laGfx1UP+nmRF9smP6/TY965kLWz41PbBK1TX8g0= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e/go.mod h1:Jjve0+eUjOLKL3PJtAhjfM2iJ0SxWio5elHqlV1ymP8= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -553,21 +553,29 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= +github.com/pion/dtls/v3 v3.0.7 
h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= +github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= -github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= +github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= +github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= +github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= -github.com/pion/transport/v3 v3.0.1 h1:gDTlPJwROfSfz6QfSi0ZmeCSkFcnWWiiR9ES0ouANiM= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1/go.mod 
h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= +github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= +github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -595,8 +603,8 @@ github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0 github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= @@ -689,6 +697,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= 
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -720,26 +730,28 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= -go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= -go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= 
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= -go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= -go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= -go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= -go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= -go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= -go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= -go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= +go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= 
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -767,8 +779,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -814,8 +826,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net 
v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -861,8 +873,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -876,8 +888,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= -golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= +golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -891,8 +903,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= -golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -960,8 +972,8 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term 
v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -969,8 +981,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= -golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= +golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -984,8 +996,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= 
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1048,8 +1060,8 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.8-0.20211022200916-316ba0b74098/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1132,10 +1144,11 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 
h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= +google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM= +google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1156,8 +1169,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= +google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= 
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1172,11 +1185,10 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2d7c65cbe..cfec1000e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -328,6 +328,45 @@ delete_auto_service_user() { echo "$PARSED_RESPONSE" } +delete_default_zitadel_admin() { + INSTANCE_URL=$1 + PAT=$2 + + # Search for the default zitadel-admin user + RESPONSE=$( + curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \ + -H 
"Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + -d '{ + "queries": [ + { + "userNameQuery": { + "userName": "zitadel-admin@", + "method": "TEXT_QUERY_METHOD_STARTS_WITH" + } + } + ] + }' + ) + + DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty') + + if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then + echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID" + + RESPONSE=$( + curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + ) + PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"') + handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE" + + else + echo "Default zitadel-admin user not found: $RESPONSE" + fi +} + init_zitadel() { echo -e "\nInitializing Zitadel with NetBird's applications\n" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" @@ -346,6 +385,9 @@ init_zitadel() { echo -n "Waiting for Zitadel to become ready " wait_api "$INSTANCE_URL" "$PAT" + echo "Deleting default zitadel-admin user..." 
+ delete_default_zitadel_admin "$INSTANCE_URL" "$PAT" + # create the zitadel project echo "Creating new zitadel project" PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") diff --git a/infrastructure_files/nginx.tmpl.conf b/infrastructure_files/nginx.tmpl.conf index 23fd760aa..f7fa4a9d0 100644 --- a/infrastructure_files/nginx.tmpl.conf +++ b/infrastructure_files/nginx.tmpl.conf @@ -17,7 +17,7 @@ upstream signal { server 127.0.0.1:10000; } upstream management { - # insert the grpc+http port of your signal container here + # insert the grpc+http port of your management container here server 127.0.0.1:8012; } @@ -75,4 +75,4 @@ server { ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem; ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem; -} \ No newline at end of file +} diff --git a/management/README.md b/management/README.md index 1122a9e76..c70285d43 100644 --- a/management/README.md +++ b/management/README.md @@ -111,3 +111,6 @@ Generate gRpc code: #!/bin/bash protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=. 
``` + + + diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 071247938..af860920f 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -26,7 +26,11 @@ func (s *BaseServer) JobManager() *server.JobManager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore()) + integratedPeerValidator, err := integrations.NewIntegratedValidator( + context.Background(), + s.PeersManager(), + s.SettingsManager(), + s.EventStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 77f899aa4..016706b3b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -105,6 +105,8 @@ type DefaultAccountManager struct { accountUpdateLocks sync.Map updateAccountPeersBufferInterval atomic.Int64 + loginFilter *loginFilter + disableDefaultPolicy bool } @@ -214,6 +216,7 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, + loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, } @@ -300,9 +303,6 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // User that performs the update has to belong to the account. 
// Returns an updated Settings func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) if err != nil { return nil, fmt.Errorf("failed to validate user permissions: %w", err) @@ -348,13 +348,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } + if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil { + return err + } + if updateAccountPeers || groupsUpdated { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } } - return transaction.SaveAccountSettings(ctx, accountID, newSettings) + return nil }) if err != nil { return nil, err @@ -498,8 +502,6 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID) //nolint ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource)) - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { @@ -535,9 +537,6 @@ func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { 
log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) @@ -678,8 +677,6 @@ func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheSto // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err @@ -1048,9 +1045,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlockAccount() - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) @@ -1143,12 +1137,20 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai } func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID - err := am.Store.SaveUser(ctx, newUser) + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID) + if err != nil { + return "", err + } + + if settings != nil && settings.Extra != nil && settings.Extra.UserApprovalRequired { + newUser.Blocked = true + newUser.PendingApproval = true + } + + err = am.Store.SaveUser(ctx, newUser) if err != nil { return "", err } @@ -1158,7 +1160,11 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, return "", err } - am.StoreEvent(ctx, 
userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) + if newUser.PendingApproval { + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true}) + } else { + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) + } return domainAccountID, nil } @@ -1357,13 +1363,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId) - defer func() { - if unlockAccount != nil { - unlockAccount() - } - }() - var addNewGroups []string var removeOldGroups []string var hasChanges bool @@ -1426,8 +1425,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return fmt.Errorf("error incrementing network serial: %w", err) } } - unlockAccount() - unlockAccount = nil return nil }) @@ -1636,17 +1633,16 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } +func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { + return am.loginFilter.allowLogin(wgPubKey, metahash) +} + func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { start := time.Now() defer func() { log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start)) }() - accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) - defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 
fmt.Errorf("error syncing peer: %w", err) @@ -1657,22 +1653,18 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } + metahash := metaHash(meta, realIP.String()) + am.loginFilter.addLogin(peerPubKey, metahash) + return peer, netMap, postureChecks, nil } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { - accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) - defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } - return nil - } func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { @@ -1681,12 +1673,6 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } - unlock := am.Store.AcquireReadLockByUID(ctx, accountID) - defer unlock() - - unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer unlockPeer() - _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) @@ -1731,7 +1717,9 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err) continue } - peers = append(peers, peer) + if peer.UserID != "" { + peers = append(peers, peer) + } } if len(peers) > 0 { err := am.expireAndUpdatePeers(ctx, accountID, peers) @@ -1827,6 +1815,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: 
types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, + Extra: &types.ExtraSettings{ + UserApprovalRequired: true, + }, }, Onboarding: types.AccountOnboarding{ OnboardingFlowPending: true, @@ -1933,6 +1924,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, + Extra: &types.ExtraSettings{ + UserApprovalRequired: true, + }, }, } @@ -2118,9 +2112,6 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee } func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) if err != nil { return fmt.Errorf("validate user permissions: %w", err) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 13154b98c..198f08bb8 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -32,6 +32,8 @@ type Manager interface { DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) + RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) 
(*types.UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) @@ -77,7 +79,7 @@ type Manager interface { DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) @@ -126,4 +128,5 @@ type Manager interface { CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) + AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index 2770cfdb0..14fb27f42 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -15,6 +15,7 @@ 
import ( "time" "github.com/golang/mock/gomock" + "github.com/prometheus/client_golang/prometheus/push" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -3046,19 +3048,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "sync", "syncAndMark") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3121,19 +3118,14 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", 
"existingPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3196,24 +3188,44 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } } +func TestMain(m *testing.M) { + exitCode := m.Run() + + if exitCode == 0 && os.Getenv("CI") == "true" { + runID := os.Getenv("GITHUB_RUN_ID") + storeEngine := os.Getenv("NETBIRD_STORE_ENGINE") + err := push.New("http://localhost:9091", "account_manager_benchmark"). + Collector(testing_tools.BenchmarkDuration). + Grouping("ci_run", runID). + Grouping("store_engine", storeEngine). + Push() + if err != nil { + log.Printf("Failed to push metrics: %v", err) + } else { + time.Sleep(1 * time.Minute) + _ = push.New("http://localhost:9091", "account_manager_benchmark"). + Grouping("ci_run", runID). + Grouping("store_engine", storeEngine). 
+ Delete() + } + } + + os.Exit(exitCode) +} + func Test_GetCreateAccountByPrivateDomain(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -3594,3 +3606,93 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { require.Error(t, err, "should fail with invalid peer ID") }) } + +func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create a domain-based account with user approval enabled + existingAccountID := "existing-account" + account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false) + account.Settings.Extra = &types.ExtraSettings{ + UserApprovalRequired: true, + } + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Set the account as domain primary account + account.IsDomainPrimaryAccount = true + account.DomainCategory = types.PrivateCategory + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Test adding new user to existing account with approval required + newUserID := "new-user-id" + userAuth := nbcontext.UserAuth{ + UserId: newUserID, + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + + acc, err := manager.Store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + require.True(t, acc.IsDomainPrimaryAccount, "Account should be primary for the domain") + require.Equal(t, "example.com", acc.Domain, "Account domain should match") + + returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) + require.NoError(t, err) + require.Equal(t, existingAccountID, returnedAccountID) + + // Verify user was created with pending approval + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) + require.NoError(t, err) + assert.True(t, user.Blocked, "User should be blocked when approval is 
required") + assert.True(t, user.PendingApproval, "User should be pending approval") + assert.Equal(t, existingAccountID, user.AccountID) +} + +func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create a domain-based account without user approval + ownerUserAuth := nbcontext.UserAuth{ + UserId: "owner-user", + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth) + require.NoError(t, err) + + // Modify the account to disable user approval + account, err := manager.Store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + account.Settings.Extra = &types.ExtraSettings{ + UserApprovalRequired: false, + } + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Test adding new user to existing account without approval required + newUserID := "new-user-id" + userAuth := nbcontext.UserAuth{ + UserId: newUserID, + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + + returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) + require.NoError(t, err) + require.Equal(t, existingAccountID, returnedAccountID) + + // Verify user was created without pending approval + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) + require.NoError(t, err) + assert.False(t, user.Blocked, "User should not be blocked when approval is not required") + assert.False(t, user.PendingApproval, "User should not be pending approval") + assert.Equal(t, existingAccountID, user.AccountID) +} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 6b4cab04a..f97e35fa9 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -177,6 +177,8 
@@ const ( AccountNetworkRangeUpdated Activity = 87 PeerIPUpdated Activity = 88 + UserApproved Activity = 90 + UserRejected Activity = 91 JobCreatedByUser Activity = 89 @@ -288,6 +290,9 @@ var activityMap = map[Activity]Code{ PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, JobCreatedByUser: {"Create Job for peer", "peer.job.create"}, + + UserApproved: {"User approved", "user.approve"}, + UserRejected: {"User rejected", "user.reject"}, } // StringCode returns a string code of the activity diff --git a/management/server/auth/jwt/extractor.go b/management/server/auth/jwt/extractor.go index fab429125..d270d0ff1 100644 --- a/management/server/auth/jwt/extractor.go +++ b/management/server/auth/jwt/extractor.go @@ -5,7 +5,7 @@ import ( "net/url" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" nbcontext "github.com/netbirdio/netbird/management/server/context" diff --git a/management/server/auth/jwt/validator.go b/management/server/auth/jwt/validator.go index 5b38ca786..239447b96 100644 --- a/management/server/auth/jwt/validator.go +++ b/management/server/auth/jwt/validator.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" ) @@ -63,12 +63,10 @@ type Validator struct { } var ( - errKeyNotFound = errors.New("unable to find appropriate key") - errInvalidAudience = errors.New("invalid audience") - errInvalidIssuer = errors.New("invalid issuer") - errTokenEmpty = errors.New("required authorization token not found") - errTokenInvalid = errors.New("token is invalid") - errTokenParsing = errors.New("token could not be parsed") + errKeyNotFound = errors.New("unable to find appropriate key") + errTokenEmpty = errors.New("required authorization token not found") + errTokenInvalid = errors.New("token is invalid") + errTokenParsing = errors.New("token could not be parsed") ) func NewValidator(issuer string, audienceList []string, 
keysLocation string, idpSignkeyRefreshEnabled bool) *Validator { @@ -88,24 +86,6 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { - // Verify 'aud' claim - var checkAud bool - for _, audience := range v.audienceList { - checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) - if checkAud { - break - } - } - if !checkAud { - return token, errInvalidAudience - } - - // Verify 'issuer' claim - checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false) - if !checkIss { - return token, errInvalidIssuer - } - // If keys are rotated, verify the keys prior to token validation if v.idpSignkeyRefreshEnabled { // If the keys are invalid, retrieve new ones @@ -144,7 +124,7 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { } // ValidateAndParse validates the token and returns the parsed token -func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { +func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... if token == "" { // If we get here, the required token is missing @@ -153,7 +133,13 @@ func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To } // Now parse the token - parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx)) + parsedToken, err := jwt.Parse( + token, + v.getKeyFunc(ctx), + jwt.WithAudience(v.audienceList...), + jwt.WithIssuer(v.issuer), + jwt.WithIssuedAt(), + ) // Check if there was an error in parsing... 
if err != nil { diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index 53d479c90..ece9dc321 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -7,7 +7,7 @@ import ( "fmt" "hash/crc32" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/netbirdio/netbird/base62" nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go index bc7066548..30a7a7161 100644 --- a/management/server/auth/manager_mock.go +++ b/management/server/auth/manager_mock.go @@ -3,7 +3,7 @@ package auth import ( "context" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index 55fb1e31a..c8015eb37 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/management/server/dns.go b/management/server/dns.go index 12aa6e21c..6b73dbd0e 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -20,29 +20,9 @@ import ( // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { - CustomZones sync.Map NameServerGroups sync.Map } -// GetCustomZone retrieves a cached custom zone -func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) { - if c == nil { - return nil, false - } - if value, ok := c.CustomZones.Load(key); ok { - return value.(*proto.CustomZone), true - } - return nil, false -} - -// SetCustomZone stores a custom zone in the cache -func (c *DNSConfigCache) SetCustomZone(key 
string, value *proto.CustomZone) { - if c == nil { - return - } - c.CustomZones.Store(key, value) -} - // GetNameServerGroup retrieves a cached name server group func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { if c == nil { @@ -113,11 +93,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave); err != nil { return err } - return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -212,14 +192,8 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC } for _, zone := range update.CustomZones { - cacheKey := zone.Domain - if cachedZone, exists := cache.GetCustomZone(cacheKey); exists { - protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone) - } else { - protoZone := convertToProtoCustomZone(zone) - cache.SetCustomZone(cacheKey, protoZone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) } for _, nsGroup := range update.NameServerGroups { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 19b89f574..a4be99acb 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -474,15 +474,6 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Results should be different for different inputs") } - // Verify that the cache contains elements from both configs - if _, exists := cache.GetCustomZone("example.com"); !exists { - t.Errorf("Cache should contain custom 
zone for example.com") - } - - if _, exists := cache.GetCustomZone("example.org"); !exists { - t.Errorf("Cache should contain custom zone for example.org") - } - if _, exists := cache.GetNameServerGroup("group1"); !exists { t.Errorf("Cache should contain name server group 'group1'") } diff --git a/management/server/group.go b/management/server/group.go index 915a87086..487cb6d97 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -67,9 +67,6 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, // CreateGroup object of the peers func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) if err != nil { return status.NewPermissionValidationError(err) @@ -96,10 +93,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { - return err - } - if err := transaction.CreateGroup(ctx, newGroup); err != nil { return status.Errorf(status.Internal, "failed to create group: %v", err) } @@ -109,7 +102,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err) } } - return nil + + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -128,9 +122,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use // UpdateGroup object of the peers func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := 
am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) if err != nil { return status.NewPermissionValidationError(err) @@ -176,11 +167,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.UpdateGroup(ctx, newGroup); err != nil { return err } - return transaction.UpdateGroup(ctx, newGroup) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -211,35 +202,45 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { + var globalErr error + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) + + if err = transaction.CreateGroup(ctx, newGroup); err != nil { + return err + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return err + } + groupIDs = append(groupIDs, newGroup.ID) events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) 
- } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + return nil + }) if err != nil { - return err + log.WithContext(ctx).Errorf("failed to create group %s: %v", newGroup.ID, err) + if len(groups) == 1 { + return err + } + globalErr = errors.Join(globalErr, err) + // continue creating other groups } + } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { - return err - } - - return transaction.CreateGroups(ctx, accountID, groupsToSave) - }) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err } @@ -252,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return nil + return globalErr } // UpdateGroups updates groups in the account. @@ -269,35 +270,45 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { + var globalErr error + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) - groupIDs = append(groupIDs, newGroup.ID) + + if err = transaction.UpdateGroup(ctx, newGroup); err != nil { + return err + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return err + } events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) 
- } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + groupIDs = append(groupIDs, newGroup.ID) + + return nil + }) if err != nil { - return err + log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) + if len(groups) == 1 { + return err + } + globalErr = errors.Join(globalErr, err) + // continue updating other groups } + } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { - return err - } - - return transaction.UpdateGroups(ctx, accountID, groupsToSave) - }) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err } @@ -310,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return nil + return globalErr } // prepareGroupEvents prepares a list of event functions to be stored. @@ -382,8 +393,6 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac // DeleteGroup object of the peers. 
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } @@ -423,11 +432,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil { return err } - return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -442,9 +451,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - var updateAccountPeers bool var err error @@ -454,11 +460,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil { return err } - return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -473,9 +479,6 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupAddResource appends resource to the group func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - var group *types.Group var updateAccountPeers bool var err error @@ -495,11 +498,11 
@@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.UpdateGroup(ctx, group); err != nil { return err } - return transaction.UpdateGroup(ctx, group) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -514,9 +517,6 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - var updateAccountPeers bool var err error @@ -526,11 +526,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil { return err } - return transaction.RemovePeerFromGroup(ctx, peerID, groupID) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -545,9 +545,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, // GroupDeleteResource removes resource from the group func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - var group *types.Group var updateAccountPeers bool var err error @@ -567,11 +564,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.UpdateGroup(ctx, group); err != nil { return err } - return transaction.UpdateGroup(ctx, group) + return transaction.IncrementNetworkSerial(ctx, 
accountID) }) if err != nil { return err @@ -607,13 +604,6 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st newGroup.ID = xid.New().String() } - for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) - if err != nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } - } - return nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index 1626a0464..31ff29cbc 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -648,7 +648,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, - newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.SkipAutoApply, ) require.NoError(t, err) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 65e931f18..ce0de5b9c 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/netip" + "os" "strings" "sync" "time" @@ -40,21 +41,30 @@ import ( internalStatus "github.com/netbirdio/netbird/shared/management/status" ) +const ( + envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" + envBlockPeers = "NB_BLOCK_SAME_PEERS" +) + // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { accountManager account.Manager settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager - jobManager *JobManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager - peerLocks 
sync.Map - authManager auth.Manager - integratedPeerValidator integrated_validator.IntegratedValidator + + peersUpdateManager *PeersUpdateManager + jobManager *JobManager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager + peerLocks sync.Map + authManager auth.Manager + + logBlockedPeers bool + blockPeersWithSameConfig bool + integratedPeerValidator integrated_validator.IntegratedValidator } // NewServer creates a new Management server @@ -85,19 +95,24 @@ func NewServer( } } + logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" + blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" + return &GRPCServer{ wgKey: key, // peerKey -> event channel - peersUpdateManager: peersUpdateManager, - jobManager: jobManager, - accountManager: accountManager, - settingsManager: settingsManager, - config: config, - secretsManager: secretsManager, - authManager: authManager, - appMetrics: appMetrics, - ephemeralManager: ephemeralManager, - integratedPeerValidator: integratedPeerValidator, + peersUpdateManager: peersUpdateManager, + jobManager: jobManager, + accountManager: accountManager, + settingsManager: settingsManager, + config: config, + secretsManager: secretsManager, + authManager: authManager, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, + logBlockedPeers: logBlockedPeers, + blockPeersWithSameConfig: blockPeersWithSameConfig, + integratedPeerValidator: integratedPeerValidator, }, nil } @@ -177,9 +192,6 @@ func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error { // notifies the connected peer of any updates (e.g. 
new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { reqStart := time.Now() - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequest() - } ctx := srv.Context() @@ -188,6 +200,27 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if err != nil { return err } + + realIP := getRealIP(ctx) + sRealIP := realIP.String() + peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) + metahashed := metaHash(peerMeta, sRealIP) + if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() + } + if s.logBlockedPeers { + log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) + } + if s.blockPeersWithSameConfig { + return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) + } + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequest() + } + // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) @@ -211,14 +244,12 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - realIP := getRealIP(ctx) - log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) - + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. 
Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) return mapError(ctx, err) @@ -236,7 +267,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } unlock() @@ -335,6 +366,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -473,6 +505,9 @@ func mapError(ctx context.Context, err error) error { default: } } + if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) { + return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error()) + } log.WithContext(ctx).Errorf("got an unhandled error: %s", err) return status.Errorf(codes.Internal, "failed handling request") } @@ -564,16 +599,9 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case of the successful registration login is also successful func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() - defer func() { - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) - } - 
}() - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequest() - } realIP := getRealIP(ctx) - log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + sRealIP := realIP.String() + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) loginReq := &proto.LoginRequest{} peerKey, err := s.parseRequest(ctx, req, loginReq) @@ -581,6 +609,24 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, err } + peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) + metahashed := metaHash(peerMeta, sRealIP) + if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if s.logBlockedPeers { + log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) + } + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() + } + if s.blockPeersWithSameConfig { + return nil, mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) + } + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } + //nolint ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) @@ -591,6 +637,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + defer func() { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) + } + }() + if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. 
Peer %s, remote addr %s", peerKey.String(), realIP) @@ -611,7 +663,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{ WireGuardPubKey: peerKey.String(), SSHKey: string(sshKey), - Meta: extractPeerMeta(ctx, loginReq.GetMeta()), + Meta: peerMeta, UserID: userID, SetupKey: loginReq.GetSetupKey(), ConnectionIP: realIP, @@ -1077,8 +1129,6 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (* return nil, mapError(ctx, err) } - s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) - log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start)) return &proto.Empty{}, nil diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 9f2afe29d..f1552d0ea 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -11,11 +11,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -198,6 +198,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.Extra != nil { settings.Extra = &types.ExtraSettings{ PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, 
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, @@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A if settings.Extra != nil { apiSettings.Extra = &api.AccountExtraSettings{ PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: settings.Extra.UserApprovalRequired, NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled, NetworkTrafficLogsGroups: settings.Extra.FlowGroups, NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled, diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 1dad33a6f..4b9b79fdc 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -15,11 +15,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) func initAccountsTestData(t *testing.T, account *types.Account) *handler { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 6c301aa72..1fd3b7f9a 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -488,33 +488,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn } return &api.PeerBatch{ - CreatedAt: peer.CreatedAt, 
- Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.GetLastLogin(), - LoginExpired: peer.Status.LoginExpired, - AccessiblePeersCount: accessiblePeersCount, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, - + CreatedAt: peer.CreatedAt, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.GetLastLogin(), + LoginExpired: peer.Status.LoginExpired, + AccessiblePeersCount: accessiblePeersCount, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, + Ephemeral: peer.Ephemeral, } } diff --git a/management/server/http/handlers/routes/routes_handler.go 
b/management/server/http/handlers/routes/routes_handler.go index 7950db1e8..7bb6f2372 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -8,17 +8,19 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/route" ) const failedToConvertRoute = "failed to convert route to response: %v" +const exitNodeCIDR = "0.0.0.0/0" + // handler is the routes handler of the account type handler struct { accountManager account.Manager @@ -124,8 +126,16 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { accessControlGroupIds = *req.AccessControlGroups } + // Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes) + skipAutoApply := false + if req.SkipAutoApply != nil { + skipAutoApply = *req.SkipAutoApply + } else if newPrefix.String() == exitNodeCIDR { + skipAutoApply = false + } + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute) + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply) if err != nil { util.WriteError(r.Context(), err, w) @@ -142,23 +152,31 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { } func (h *handler) validateRoute(req 
api.PostApiRoutesJSONRequestBody) error { - if req.Network != nil && req.Domains != nil { + return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId) +} + +func (h *handler) validateRouteUpdate(req api.PutApiRoutesRouteIdJSONRequestBody) error { + return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId) +} + +func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *string, peerGroups *[]string, networkId string) error { + if network != nil && domains != nil { return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") } - if req.Network == nil && req.Domains == nil { + if network == nil && domains == nil { return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided") } - if req.Peer == nil && req.PeerGroups == nil { + if peer == nil && peerGroups == nil { return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided") } - if req.Peer != nil && req.PeerGroups != nil { + if peer != nil && peerGroups != nil { return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided") } - if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { + if utf8.RuneCountInString(networkId) > route.MaxNetIDChar || networkId == "" { return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters", route.MaxNetIDChar) } @@ -195,7 +213,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { return } - if err := h.validateRoute(req); err != nil { + if err := h.validateRouteUpdate(req); err != nil { util.WriteError(r.Context(), err, w) return } @@ -205,15 +223,24 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { peerID = *req.Peer } + // Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes) + skipAutoApply := false + if 
req.SkipAutoApply != nil { + skipAutoApply = *req.SkipAutoApply + } else if req.Network != nil && *req.Network == exitNodeCIDR { + skipAutoApply = false + } + newRoute := &route.Route{ - ID: route.ID(routeID), - NetID: route.NetID(req.NetworkId), - Masquerade: req.Masquerade, - Metric: req.Metric, - Description: req.Description, - Enabled: req.Enabled, - Groups: req.Groups, - KeepRoute: req.KeepRoute, + ID: route.ID(routeID), + NetID: route.NetID(req.NetworkId), + Masquerade: req.Masquerade, + Metric: req.Metric, + Description: req.Description, + Enabled: req.Enabled, + Groups: req.Groups, + KeepRoute: req.KeepRoute, + SkipAutoApply: skipAutoApply, } if req.Domains != nil { @@ -321,18 +348,19 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { } network := serverRoute.Network.String() route := &api.Route{ - Id: string(serverRoute.ID), - Description: serverRoute.Description, - NetworkId: string(serverRoute.NetID), - Enabled: serverRoute.Enabled, - Peer: &serverRoute.Peer, - Network: &network, - Domains: &domains, - NetworkType: serverRoute.NetworkType.String(), - Masquerade: serverRoute.Masquerade, - Metric: serverRoute.Metric, - Groups: serverRoute.Groups, - KeepRoute: serverRoute.KeepRoute, + Id: string(serverRoute.ID), + Description: serverRoute.Description, + NetworkId: string(serverRoute.NetID), + Enabled: serverRoute.Enabled, + Peer: &serverRoute.Peer, + Network: &network, + Domains: &domains, + NetworkType: serverRoute.NetworkType.String(), + Masquerade: serverRoute.Masquerade, + Metric: serverRoute.Metric, + Groups: serverRoute.Groups, + KeepRoute: serverRoute.KeepRoute, + SkipAutoApply: &serverRoute.SkipAutoApply, } if len(serverRoute.PeerGroups) > 0 { diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index fc0e112f7..466a7987f 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ 
b/management/server/http/handlers/routes/routes_handler_test.go @@ -15,13 +15,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/shared/management/domain" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -62,21 +62,22 @@ func initRoutesTestData() *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { - if routeID == existingRouteID { + switch routeID { + case existingRouteID: return baseExistingRoute, nil - } - if routeID == existingRouteID2 { + case existingRouteID2: route := baseExistingRoute.Copy() route.PeerGroups = []string{existingGroupID} return route, nil - } else if routeID == existingRouteID3 { + case existingRouteID3: route := baseExistingRoute.Copy() route.Domains = domain.List{existingDomain} return route, nil + default: + return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) } - return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix 
netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool, skipAutoApply bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -103,6 +104,7 @@ func initRoutesTestData() *handler { Groups: groups, KeepRoute: keepRoute, AccessControlGroups: accessControlGroups, + SkipAutoApply: skipAutoApply, }, nil }, SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { @@ -190,19 +192,20 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBuffer( - []byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))), + []byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"],"skip_auto_apply":false}`, existingPeerID, existingGroupID))), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", - Network: util.ToPtr("192.168.0.0/16"), - Peer: &existingPeerID, - NetworkType: route.IPv4NetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + SkipAutoApply: util.ToPtr(false), }, }, { @@ -210,21 +213,22 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBuffer( - 
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))), + []byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID))), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "domainNet", - Network: util.ToPtr("invalid Prefix"), - KeepRoute: true, - Domains: &[]string{existingDomain}, - Peer: &existingPeerID, - NetworkType: route.DomainNetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, + Id: existingRouteID, + Description: "Post", + NetworkId: "domainNet", + Network: util.ToPtr("invalid Prefix"), + KeepRoute: true, + Domains: &[]string{existingDomain}, + Peer: &existingPeerID, + NetworkType: route.DomainNetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + SkipAutoApply: util.ToPtr(false), }, }, { @@ -232,7 +236,7 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBuffer( - []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"],\"skip_auto_apply\":false}", existingPeerID, existingGroupID, existingGroupID))), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ @@ -246,6 +250,7 @@ func TestRoutesHandlers(t *testing.T) { Enabled: false, Groups: []string{existingGroupID}, AccessControlGroups: 
&[]string{existingGroupID}, + SkipAutoApply: util.ToPtr(false), }, }, { @@ -336,60 +341,63 @@ func TestRoutesHandlers(t *testing.T) { name: "Network PUT OK", requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, - requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)), + requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"is_selected\":true}", existingPeerID, existingGroupID)), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", - Network: util.ToPtr("192.168.0.0/16"), - Peer: &existingPeerID, - NetworkType: route.IPv4NetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + SkipAutoApply: util.ToPtr(false), }, }, { name: "Domains PUT OK", requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, - requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)), + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID)), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", 
- Network: util.ToPtr("invalid Prefix"), - Domains: &[]string{existingDomain}, - Peer: &existingPeerID, - NetworkType: route.DomainNetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, - KeepRoute: true, + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("invalid Prefix"), + Domains: &[]string{existingDomain}, + Peer: &existingPeerID, + NetworkType: route.DomainNetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + KeepRoute: true, + SkipAutoApply: util.ToPtr(false), }, }, { name: "PUT OK when peer_groups provided", requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, - requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingGroupID, existingGroupID)), + requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"],\"skip_auto_apply\":false}", existingGroupID, existingGroupID)), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", - Network: util.ToPtr("192.168.0.0/16"), - Peer: &emptyString, - PeerGroups: &[]string{existingGroupID}, - NetworkType: route.IPv4NetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("192.168.0.0/16"), + Peer: &emptyString, + PeerGroups: &[]string{existingGroupID}, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + SkipAutoApply: util.ToPtr(false), }, }, { diff --git a/management/server/http/handlers/users/users_handler.go 
b/management/server/http/handlers/users/users_handler.go index bcd637db4..4e03e5e9b 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -9,11 +9,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" nbcontext "github.com/netbirdio/netbird/management/server/context" ) @@ -31,6 +31,8 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS") addUsersTokensEndpoint(accountManager, router) } @@ -323,17 +325,76 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { } isCurrent := user.ID == currenUserID + return &api.User{ - Id: user.ID, - Name: user.Name, - Email: user.Email, - Role: user.Role, - AutoGroups: autoGroups, - Status: userStatus, - IsCurrent: &isCurrent, - IsServiceUser: &user.IsServiceUser, - IsBlocked: user.IsBlocked, - LastLogin: &user.LastLogin, - Issued: &user.Issued, + Id: user.ID, + Name: user.Name, + Email: user.Email, + Role: user.Role, + AutoGroups: autoGroups, + Status: userStatus, + IsCurrent: &isCurrent, + IsServiceUser: 
&user.IsServiceUser, + IsBlocked: user.IsBlocked, + LastLogin: &user.LastLogin, + Issued: &user.Issued, + PendingApproval: user.PendingApproval, } } + +// approveUser is a POST request to approve a user that is pending approval +func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + userResponse := toUserResponse(user, userAuth.UserId) + util.WriteJSONObject(r.Context(), w, userResponse) +} + +// rejectUser is a DELETE request to reject a user that is pending approval +func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index 
f7dc81919..e08004218 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -16,13 +16,13 @@ import ( "github.com/stretchr/testify/require" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -725,3 +725,133 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri } return modules } + +func TestApproveUserEndpoint(t *testing.T) { + adminUser := &types.User{ + Id: "admin-user", + Role: types.UserRoleAdmin, + AccountID: existingAccountID, + AutoGroups: []string{}, + } + + pendingUser := &types.User{ + Id: "pending-user", + Role: types.UserRoleUser, + AccountID: existingAccountID, + Blocked: true, + PendingApproval: true, + AutoGroups: []string{}, + } + + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestingUser *types.User + }{ + { + name: "approve user as admin should return 200", + expectedStatus: 200, + expectedBody: true, + requestingUser: adminUser, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{} + am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + approvedUserInfo := &types.UserInfo{ + ID: pendingUser.Id, + Email: "pending@example.com", + Name: "Pending User", + Role: string(pendingUser.Role), + AutoGroups: []string{}, 
+ IsServiceUser: false, + IsBlocked: false, + PendingApproval: false, + LastLogin: time.Now(), + Issued: types.UserIssuedAPI, + } + return approvedUserInfo, nil + } + + handler := newHandler(am) + router := mux.NewRouter() + router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST") + + req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) + require.NoError(t, err) + + userAuth := nbcontext.UserAuth{ + AccountId: existingAccountID, + UserId: tc.requestingUser.Id, + } + ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedBody { + var response api.User + err = json.Unmarshal(rr.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, "pending-user", response.Id) + assert.False(t, response.IsBlocked) + assert.False(t, response.PendingApproval) + } + }) + } +} + +func TestRejectUserEndpoint(t *testing.T) { + adminUser := &types.User{ + Id: "admin-user", + Role: types.UserRoleAdmin, + AccountID: existingAccountID, + AutoGroups: []string{}, + } + + tt := []struct { + name string + expectedStatus int + requestingUser *types.User + }{ + { + name: "reject user as admin should return 200", + expectedStatus: 200, + requestingUser: adminUser, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{} + am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + return nil + } + + handler := newHandler(am) + router := mux.NewRouter() + router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE") + + req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) + require.NoError(t, err) + + userAuth := nbcontext.UserAuth{ + AccountId: existingAccountID, + UserId: tc.requestingUser.Id, + } + ctx := 
nbcontext.SetUserAuthInContext(req.Context(), userAuth) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + }) + } +} diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index f221e64a9..6091a4c31 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -13,9 +13,9 @@ import ( "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 2285ed244..d815f5422 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -8,16 +8,15 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/auth" nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/util" - "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) const ( diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go 
b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 52737e4eb..3fe3fe809 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -17,8 +17,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) const modulePeers = "peers" @@ -47,7 +48,7 @@ func BenchmarkUpdatePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -65,7 +66,7 @@ func BenchmarkUpdatePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) }) } } @@ -82,7 +83,7 @@ func BenchmarkGetOnePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -92,7 +93,7 @@ 
func BenchmarkGetOnePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) }) } } @@ -109,7 +110,7 @@ func BenchmarkGetAllPeers(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -119,7 +120,7 @@ func BenchmarkGetAllPeers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) }) } } @@ -136,7 +137,7 @@ func BenchmarkDeletePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -146,7 +147,7 @@ func BenchmarkDeletePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) }) } } diff --git 
a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 9404c4ee4..36b226db0 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -17,8 +17,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) // Map to store peers, groups, users, and setupKeys by name @@ -47,7 +48,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -69,7 +70,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) }) } } @@ -86,7 +87,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, 
"../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -109,7 +110,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) }) } } @@ -126,7 +127,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -136,7 +137,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) }) } } @@ -153,7 +154,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -163,7 +164,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - 
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) }) } } @@ -180,7 +181,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) b.ResetTimer() @@ -190,7 +191,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 844b3e7a6..2868a20bd 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -18,8 +18,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) const moduleUsers = "users" @@ -46,7 +47,7 @@ func BenchmarkUpdateUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, 
func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -71,7 +72,7 @@ func BenchmarkUpdateUser(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) }) } } @@ -84,18 +85,18 @@ func BenchmarkGetOneUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) for i := 0; i < b.N; i++ { - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) }) } } @@ -110,18 +111,18 @@ func BenchmarkGetAllUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := 
testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) for i := 0; i < b.N; i++ { - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) }) } } @@ -136,7 +137,7 @@ func BenchmarkDeleteUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -147,7 +148,7 @@ func BenchmarkDeleteUsers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index 9f04e3c24..1079de4aa 100644 --- 
a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -15,9 +15,10 @@ import ( "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) func Test_SetupKeys_Create(t *testing.T) { @@ -287,7 +288,7 @@ func Test_SetupKeys_Create(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(user.name+" - "+tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -572,7 +573,7 @@ func Test_SetupKeys_Update(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -751,7 +752,7 @@ func Test_SetupKeys_Get(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), 
user.userId) @@ -903,7 +904,7 @@ func Test_SetupKeys_GetAll(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) @@ -1087,7 +1088,7 @@ func Test_SetupKeys_Delete(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go new file mode 100644 index 000000000..741f03f18 --- /dev/null +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -0,0 +1,137 @@ +package channel + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/netbirdio/management-integrations/integrations" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + http2 "github.com/netbirdio/netbird/management/server/http" + 
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/users" +) + +func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) + done := make(chan struct{}) + if validateUpdate { + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + } + + geoMock := &geolocation.Mock{} + validatorMock := server.MockIntegratedValidator{} + proxyController := integrations.NewController(store) + userManager := users.NewManager(store) + permissionsManager := permissions.NewManager(store) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, 
"", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + // @note this is required so that PAT's validate from store, but JWT's are mocked + authManager := auth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, + MarkPATUsedFunc: authManager.MarkPATUsed, + GetPATInfoFunc: authManager.GetPATInfo, + } + + networksManagerMock := networks.NewManagerMock() + resourcesManagerMock := resources.NewManagerMock() + routersManagerMock := routers.NewManagerMock() + groupsManagerMock := groups.NewManagerMock() + peersManager := peers.NewManager(store, permissionsManager) + + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, done +} + +func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + } +} + +func mockValidateAndParseToken(_ 
context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { + userAuth := nbcontext.UserAuth{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + userAuth.UserId = token + userAuth.AccountId = "testAccountId" + userAuth.Domain = "test.com" + userAuth.DomainCategory = "private" + case "otherUserId": + userAuth.UserId = "otherUserId" + userAuth.AccountId = "otherAccountId" + userAuth.Domain = "other.com" + userAuth.DomainCategory = "private" + case "invalidToken": + return userAuth, nil, errors.New("invalid token") + } + + jwtToken := jwt.New(jwt.SigningMethodHS256) + return userAuth, jwtToken, nil +} diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index cc3a2a8f6..2186ecd46 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -3,7 +3,6 @@ package testing_tools import ( "bytes" "context" - "errors" "fmt" "io" "net" @@ -14,32 +13,12 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/users" - - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/management/server/groups" - 
nbhttp "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/networks" - "github.com/netbirdio/netbird/management/server/networks/resources" - "github.com/netbirdio/netbird/management/server/networks/routers" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" ) @@ -223,11 +202,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta return content, expectedStatus == http.StatusOK } -func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { +func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) { b.Helper() ctx := context.Background() - account, err := am.GetAccount(ctx, TestAccountId) + acc, err := am.GetAccount(ctx, TestAccountId) if err != nil { b.Fatalf("Failed to get account: %v", err) } @@ -243,23 +222,23 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: TestUserId, } - account.Peers[peer.ID] = peer + acc.Peers[peer.ID] = peer } // Create users for i := 0; i < users; i++ { user := &types.User{ Id: fmt.Sprintf("olduser-%d", i), - AccountID: account.Id, + AccountID: acc.Id, Role: types.UserRoleUser, } - account.Users[user.Id] = user + acc.Users[user.Id] = user } for i := 0; i < setupKeys; i++ { key := &types.SetupKey{ Id: fmt.Sprintf("oldkey-%d", i), - AccountID: account.Id, + AccountID: acc.Id, AutoGroups: []string{"someGroupID"}, UpdatedAt: time.Now().UTC(), ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)), @@ -267,11 +246,11 @@ func PopulateTestData(b *testing.B, am 
*server.DefaultAccountManager, peers, gro Type: "reusable", UsageLimit: 0, } - account.SetupKeys[key.Id] = key + acc.SetupKeys[key.Id] = key } // Create groups and policies - account.Policies = make([]*types.Policy, 0, groups) + acc.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) group := &types.Group{ @@ -282,7 +261,7 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } - account.Groups[groupID] = group + acc.Groups[groupID] = group // Create a policy for this group policy := &types.Policy{ @@ -302,10 +281,10 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro }, }, } - account.Policies = append(account.Policies, policy) + acc.Policies = append(acc.Policies, policy) } - account.PostureChecks = []*posture.Checks{ + acc.PostureChecks = []*posture.Checks{ { ID: "PostureChecksAll", Name: "All", @@ -317,52 +296,38 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro }, } - err = am.Store.SaveAccount(context.Background(), account) + store := am.GetStore() + + err = store.SaveAccount(context.Background(), acc) if err != nil { b.Fatalf("Failed to save account: %v", err) } } -func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { +func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { b.Helper() - branch := os.Getenv("GIT_BRANCH") - if branch == "" { - b.Fatalf("environment variable GIT_BRANCH is not set") - } - if recorder.Code != http.StatusOK { b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code) } + EvaluateBenchmarkResults(b, testCase, duration, module, operation) + +} + +func 
EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) { + b.Helper() + + branch := os.Getenv("GIT_BRANCH") + if branch == "" && os.Getenv("CI") == "true" { + b.Fatalf("environment variable GIT_BRANCH is not set") + } + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch) gauge.Set(msPerOp) b.ReportMetric(msPerOp, "ms/op") - -} - -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} - - switch token { - case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": - userAuth.UserId = token - userAuth.AccountId = "testAccountId" - userAuth.Domain = "test.com" - userAuth.DomainCategory = "private" - case "otherUserId": - userAuth.UserId = "otherUserId" - userAuth.AccountId = "otherAccountId" - userAuth.Domain = "other.com" - userAuth.DomainCategory = "private" - case "invalidToken": - return userAuth, nil, errors.New("invalid token") - } - - jwtToken := jwt.New(jwt.SigningMethodHS256) - return userAuth, jwtToken, nil } diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 497f1944f..1eb8434d3 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "context" + "encoding/base64" "encoding/json" "fmt" "io" @@ -16,7 +17,6 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" ) @@ -231,7 +231,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" { return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := 
jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index f8a0e1210..66c16870b 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -11,12 +11,11 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/telemetry" - - "github.com/golang-jwt/jwt" - "github.com/stretchr/testify/assert" ) type mockHTTPClient struct { diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 00d30d645..2f87a9bba 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -2,6 +2,7 @@ package idp import ( "context" + "encoding/base64" "fmt" "io" "net/http" @@ -11,7 +12,6 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" @@ -166,7 +166,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) ( return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 35b86764d..393a39e3e 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -2,6 +2,7 @@ package idp import ( "context" + "encoding/base64" "fmt" "io" "net/http" @@ -10,7 +11,6 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" 
"github.com/netbirdio/netbird/management/server/telemetry" @@ -168,7 +168,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index 07d84058c..c611317ab 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -2,6 +2,7 @@ package idp import ( "context" + "encoding/base64" "fmt" "io" "net/http" @@ -11,7 +12,6 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/telemetry" @@ -158,7 +158,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 343357927..24228346a 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -2,6 +2,7 @@ package idp import ( "context" + "encoding/base64" "errors" "fmt" "io" @@ -12,7 +13,6 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/telemetry" @@ -253,7 +253,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW return jwtToken, 
fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 509022015..21f11bfce 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -46,9 +46,6 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context, groups = []string{} } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { diff --git a/management/server/integrations/port_forwarding/controller.go b/management/server/integrations/port_forwarding/controller.go index 6f062bb12..f2ce81839 100644 --- a/management/server/integrations/port_forwarding/controller.go +++ b/management/server/integrations/port_forwarding/controller.go @@ -3,12 +3,14 @@ package port_forwarding import ( "context" + "github.com/netbirdio/netbird/management/server/peer" nbtypes "github.com/netbirdio/netbird/management/server/types" ) type Controller interface { - SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string) - GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error) + SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer) + GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) + GetProxyNetworkMapsAll(ctx 
context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) IsPeerInIngressPorts(ctx context.Context, accountID, peerID string) (bool, error) } @@ -19,11 +21,15 @@ func NewControllerMock() *ControllerMock { return &ControllerMock{} } -func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string) { +func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer) { // noop } -func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error) { +func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) { + return make(map[string]*nbtypes.NetworkMap), nil +} + +func (c *ControllerMock) GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) { return make(map[string]*nbtypes.NetworkMap), nil } diff --git a/management/server/loginfilter.go b/management/server/loginfilter.go new file mode 100644 index 000000000..8604af6e2 --- /dev/null +++ b/management/server/loginfilter.go @@ -0,0 +1,160 @@ +package server + +import ( + "hash/fnv" + "math" + "sync" + "time" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +const ( + reconnThreshold = 5 * time.Minute + baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit + reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban + metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer +) + +type lfConfig struct { + reconnThreshold time.Duration + baseBlockDuration time.Duration + reconnLimitForBan int + metaChangeLimit 
int +} + +func initCfg() *lfConfig { + return &lfConfig{ + reconnThreshold: reconnThreshold, + baseBlockDuration: baseBlockDuration, + reconnLimitForBan: reconnLimitForBan, + metaChangeLimit: metaChangeLimit, + } +} + +type loginFilter struct { + mu sync.RWMutex + cfg *lfConfig + logged map[string]*peerState +} + +type peerState struct { + currentHash uint64 + sessionCounter int + sessionStart time.Time + lastSeen time.Time + isBanned bool + banLevel int + banExpiresAt time.Time + metaChangeCounter int + metaChangeWindowStart time.Time +} + +func newLoginFilter() *loginFilter { + return newLoginFilterWithCfg(initCfg()) +} + +func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter { + return &loginFilter{ + logged: make(map[string]*peerState), + cfg: cfg, + } +} + +func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool { + l.mu.RLock() + defer func() { + l.mu.RUnlock() + }() + state, ok := l.logged[wgPubKey] + if !ok { + return true + } + if state.isBanned && time.Now().Before(state.banExpiresAt) { + return false + } + if metaHash != state.currentHash { + if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit { + return false + } + } + return true +} + +func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) { + now := time.Now() + l.mu.Lock() + defer func() { + l.mu.Unlock() + }() + + state, ok := l.logged[wgPubKey] + + if !ok { + l.logged[wgPubKey] = &peerState{ + currentHash: metaHash, + sessionCounter: 1, + sessionStart: now, + lastSeen: now, + metaChangeWindowStart: now, + metaChangeCounter: 1, + } + return + } + + if state.isBanned && now.After(state.banExpiresAt) { + state.isBanned = false + } + + if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) { + state.banLevel = 0 + } + + if metaHash != state.currentHash { + if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) { + state.metaChangeWindowStart = now + 
state.metaChangeCounter = 1 + } else { + state.metaChangeCounter++ + } + state.currentHash = metaHash + state.sessionCounter = 1 + state.sessionStart = now + state.lastSeen = now + return + } + + state.sessionCounter++ + if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold { + state.isBanned = true + state.banLevel++ + + backoffFactor := math.Pow(2, float64(state.banLevel-1)) + duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor) + state.banExpiresAt = now.Add(duration) + + state.sessionCounter = 0 + state.sessionStart = now + } + state.lastSeen = now +} + +func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 { + h := fnv.New64a() + + h.Write([]byte(meta.WtVersion)) + h.Write([]byte(meta.OSVersion)) + h.Write([]byte(meta.KernelVersion)) + h.Write([]byte(meta.Hostname)) + h.Write([]byte(meta.SystemSerialNumber)) + h.Write([]byte(pubip)) + + macs := uint64(0) + for _, na := range meta.NetworkAddresses { + for _, r := range na.Mac { + macs += uint64(r) + } + } + + return h.Sum64() + macs +} diff --git a/management/server/loginfilter_test.go b/management/server/loginfilter_test.go new file mode 100644 index 000000000..65782dd9d --- /dev/null +++ b/management/server/loginfilter_test.go @@ -0,0 +1,275 @@ +package server + +import ( + "hash/fnv" + "math" + "math/rand" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func testAdvancedCfg() *lfConfig { + return &lfConfig{ + reconnThreshold: 50 * time.Millisecond, + baseBlockDuration: 100 * time.Millisecond, + reconnLimitForBan: 3, + metaChangeLimit: 2, + } +} + +type LoginFilterTestSuite struct { + suite.Suite + filter *loginFilter +} + +func (s *LoginFilterTestSuite) SetupTest() { + s.filter = newLoginFilterWithCfg(testAdvancedCfg()) +} + +func TestLoginFilterTestSuite(t *testing.T) { + suite.Run(t, new(LoginFilterTestSuite)) +} + 
+func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.True(s.filter.allowLogin(pubKey, meta)) + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(1, s.filter.logged[pubKey].sessionCounter) +} + +func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + limit := s.filter.cfg.reconnLimitForBan + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + + s.False(s.filter.allowLogin(pubKey, meta)) + s.Require().Contains(s.filter.logged, pubKey) + s.True(s.filter.logged[pubKey].isBanned) +} + +func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + limit := s.filter.cfg.reconnLimitForBan + baseBan := s.filter.cfg.baseBlockDuration + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + s.Require().Contains(s.filter.logged, pubKey) + s.True(s.filter.logged[pubKey].isBanned) + s.Equal(1, s.filter.logged[pubKey].banLevel) + firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) + s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond)) + + s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second) + s.filter.logged[pubKey].isBanned = false + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + s.True(s.filter.logged[pubKey].isBanned) + s.Equal(2, s.filter.logged[pubKey].banLevel) + secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) + expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1)) + s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond)) +} + +func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.filter.logged[pubKey] = &peerState{ + isBanned: true, + banExpiresAt: 
time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)), + } + + s.True(s.filter.allowLogin(pubKey, meta)) + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.False(s.filter.logged[pubKey].isBanned) +} + +func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.filter.logged[pubKey] = &peerState{ + currentHash: meta, + banLevel: 3, + lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration), + } + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(0, s.filter.logged[pubKey].banLevel) +} + +func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() { + pubKey := "PUB_KEY_A" + limit := s.filter.cfg.metaChangeLimit + + for i := range limit { + s.filter.addLogin(pubKey, uint64(i+1)) + } + + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter) + + isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1)) + + s.False(isAllowed, "should block new meta hash after limit is reached") +} + +func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() { + pubKey := "PUB_KEY_A" + meta1 := uint64(1) + meta2 := uint64(2) + meta3 := uint64(3) + + s.filter.addLogin(pubKey, meta1) + s.filter.addLogin(pubKey, meta2) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter) + s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window") + + s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second)) + + s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires") + + s.filter.addLogin(pubKey, meta3) + s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset") +} + +func BenchmarkHashingMethods(b *testing.B) { + meta := nbpeer.PeerSystemMeta{ + WtVersion: "1.25.1", 
+ OSVersion: "Ubuntu 22.04.3 LTS", + KernelVersion: "5.15.0-76-generic", + Hostname: "prod-server-database-01", + SystemSerialNumber: "PC-1234567890", + NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}}, + } + pubip := "8.8.8.8" + + var resultString string + var resultUint uint64 + + b.Run("BuilderString", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultString = builderString(meta, pubip) + } + }) + + b.Run("FnvHashToString", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultString = fnvHashToString(meta, pubip) + } + }) + + b.Run("FnvHashToUint64 - used", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultUint = metaHash(meta, pubip) + } + }) + + _ = resultString + _ = resultUint +} + +func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string { + h := fnv.New64a() + + if len(meta.NetworkAddresses) != 0 { + for _, na := range meta.NetworkAddresses { + h.Write([]byte(na.Mac)) + } + } + + h.Write([]byte(meta.WtVersion)) + h.Write([]byte(meta.OSVersion)) + h.Write([]byte(meta.KernelVersion)) + h.Write([]byte(meta.Hostname)) + h.Write([]byte(meta.SystemSerialNumber)) + h.Write([]byte(pubip)) + + return strconv.FormatUint(h.Sum64(), 16) +} + +func builderString(meta nbpeer.PeerSystemMeta, pubip string) string { + mac := getMacAddress(meta.NetworkAddresses) + estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + + len(pubip) + len(mac) + 6 + + var b strings.Builder + b.Grow(estimatedSize) + + b.WriteString(meta.WtVersion) + b.WriteByte('|') + b.WriteString(meta.OSVersion) + b.WriteByte('|') + b.WriteString(meta.KernelVersion) + b.WriteByte('|') + b.WriteString(meta.Hostname) + b.WriteByte('|') + b.WriteString(meta.SystemSerialNumber) + b.WriteByte('|') + b.WriteString(pubip) + + return b.String() 
+} + +func getMacAddress(nas []nbpeer.NetworkAddress) string { + if len(nas) == 0 { + return "" + } + macs := make([]string, 0, len(nas)) + for _, na := range nas { + macs = append(macs, na.Mac) + } + return strings.Join(macs, "/") +} + +func BenchmarkLoginFilter_ParallelLoad(b *testing.B) { + filter := newLoginFilterWithCfg(testAdvancedCfg()) + numKeys := 100000 + pubKeys := make([]string, numKeys) + for i := range numKeys { + pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for pb.Next() { + key := pubKeys[r.Intn(numKeys)] + meta := r.Uint64() + + if filter.allowLogin(key, meta) { + filter.addLogin(key, meta) + } + } + }) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2a27bf6a7..cc79082a6 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -61,7 +61,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, 
keepRoute bool, isSelected bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -95,6 +95,8 @@ type MockAccountManager struct { LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error + ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) + RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() account.ExternalCacheManager @@ -516,9 +518,9 @@ func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userI } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, 
groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute, isSelected) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } @@ -629,6 +631,20 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented") } +func (am *MockAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + if am.ApproveUserFunc != nil { + return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented") +} + +func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + if am.RejectUserFunc != nil { + return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID) + } + return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented") +} + // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if am.GetNameServerGroupFunc != nil { @@ -977,3 +993,10 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n } return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } + +func (am *MockAccountManager) 
AllowSync(key string, hash uint64) bool { + if am.AllowSyncFunc != nil { + return am.AllowSyncFunc(key, hash) + } + return true +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 1ee8805fc..f278e1761 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -37,9 +37,6 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -73,11 +70,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, newNSGroup) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return nil, err @@ -94,9 +91,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if nsGroupToSave == nil { return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } @@ -127,11 +121,11 @@ func (am 
*DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, nsGroupToSave) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -148,9 +142,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -173,11 +164,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil { return err } - return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 2bab0e289..b6706ca45 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -70,9 +70,6 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network network.ID = xid.New().String() - unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) - defer unlock() - err = m.store.SaveNetwork(ctx, network) if err != nil { return nil, fmt.Errorf("failed to save network: %w", err) @@ -104,9 +101,6 @@ func (m *managerImpl) 
UpdateNetwork(ctx context.Context, userID string, network return nil, status.NewPermissionDeniedError() } - unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) - defer unlock() - _, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) if err != nil { return nil, fmt.Errorf("failed to get network: %w", err) @@ -131,9 +125,6 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return fmt.Errorf("failed to get network: %w", err) } - unlock := m.store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - var eventsToStore []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) @@ -167,15 +158,15 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return fmt.Errorf("failed to delete network: %w", err) } + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) + }) + err = transaction.IncrementNetworkSerial(ctx, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } - eventsToStore = append(eventsToStore, func() { - m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) - }) - return nil }) if err != nil { diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index d0b29075b..294f51676 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -108,9 +108,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return nil, fmt.Errorf("failed to create new network resource: %w", err) } - unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) - defer unlock() - var 
eventsToStore []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name) @@ -204,9 +201,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc resource.Domain = domain resource.Prefix = prefix - unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) - defer unlock() - var eventsToStore []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) @@ -315,9 +309,6 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net return status.NewPermissionDeniedError() } - unlock := m.store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - var events []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID) diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index ca99e4fd1..82cac424a 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -88,9 +88,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t return nil, status.NewPermissionDeniedError() } - unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) - defer unlock() - var network *networkTypes.Network err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) @@ -157,9 +154,6 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return nil, status.NewPermissionDeniedError() } - unlock := m.store.AcquireWriteLockByUID(ctx, 
router.AccountID) - defer unlock() - var network *networkTypes.Network err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) @@ -203,9 +197,6 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo return status.NewPermissionDeniedError() } - unlock := m.store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - var event func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID) diff --git a/management/server/peer.go b/management/server/peer.go index 16abf2b40..3b9622d09 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -192,9 +192,6 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. 
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -461,9 +458,6 @@ func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -486,7 +480,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) + peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { return err } @@ -500,10 +494,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { - return fmt.Errorf("failed to remove peer from groups: %w", err) - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -553,7 +543,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin } customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - proxyNetworkMaps, err := 
am.proxyController.GetProxyNetworkMaps(ctx, account.Id) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return nil, err @@ -625,6 +615,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found") } + if user.PendingApproval { + return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") + } groupsToAdd = user.AutoGroups opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser @@ -735,13 +728,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s newPeer.DNSLabel = freeLabel newPeer.IP = freeIP - unlock := am.Store.AcquireReadLockByUID(ctx, accountID) - defer func() { - if unlock != nil { - unlock() - } - }() - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = transaction.AddPeerToAccount(ctx, newPeer) if err != nil { @@ -793,14 +779,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil }) if err == nil { - unlock() - unlock = nil break } if isUniqueConstraintError(err) { - unlock() - unlock = nil log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err) continue } @@ -959,15 +941,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer } } - unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID) - defer unlockAccount() - unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey) - defer func() { - if unlockPeer != nil { - unlockPeer() - } - }() - var peer *nbpeer.Peer var updateRemotePeers bool var isRequiresApproval bool @@ -1048,9 +1021,6 @@ func (am *DefaultAccountManager) 
LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } - unlockPeer() - unlockPeer = nil - if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -1182,7 +1152,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return nil, nil, nil, err @@ -1355,7 +1325,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return @@ -1494,7 +1464,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return @@ -1682,7 +1652,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto } dnsDomain := am.GetDNSDomain(settings) - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } diff --git 
a/management/server/peer/peer.go b/management/server/peer/peer.go index f7140e254..6a6d1c91d 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -24,7 +24,7 @@ type Peer struct { // Meta is a Peer system meta data Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) - Name string + Name string `gorm:"index"` // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // uniqueness index per accountID (check migrations) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 92b240285..b2d5b7e39 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -26,6 +26,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" @@ -989,19 +990,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark 
%s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -1609,7 +1605,6 @@ func Test_LoginPeer(t *testing.T) { testCases := []struct { name string setupKey string - wireGuardPubKey string expectExtraDNSLabelsMismatch bool extraDNSLabels []string expectLoginError bool @@ -1973,7 +1968,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, - route.Groups, []string{}, true, userID, route.KeepRoute, + route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply, ) require.NoError(t, err) @@ -2388,3 +2383,186 @@ func TestBufferUpdateAccountPeers(t *testing.T) { assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) } + +func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Try to add peer with pending approval user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + peer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.AddPeer(context.Background(), "", 
pendingUser.Id, peer) + require.Error(t, err) + assert.Contains(t, err.Error(), "user pending approval cannot add peers") +} + +func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create regular user (not pending approval) + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + // Try to add peer with regular user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + peer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) + require.NoError(t, err, "Regular user should be able to add peers") +} + +func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Create a peer using AddPeer method for the pending user (simulate existing peer) + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + // Set the user to not be pending initially so peer can be added + pendingUser.Blocked = false + 
pendingUser.PendingApproval = false + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Add peer using regular flow + newPeer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + WtVersion: "0.28.0", + }, + } + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) + require.NoError(t, err) + + // Now set the user back to pending approval after peer was created + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Try to login with pending approval user + login := types.PeerLogin{ + WireGuardPubKey: existingPeer.Key, + UserID: pendingUser.Id, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.LoginPeer(context.Background(), login) + require.Error(t, err) + e, ok := status.FromError(err) + require.True(t, ok, "error is not a gRPC status error") + assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code") +} + +func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create regular user (not pending approval) + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + // Add peer using regular flow for the regular user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + 
Hostname: "test-peer", + OS: "linux", + WtVersion: "0.28.0", + }, + } + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) + require.NoError(t, err) + + // Try to login with regular user + login := types.PeerLogin{ + WireGuardPubKey: existingPeer.Key, + UserID: regularUser.Id, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.LoginPeer(context.Background(), login) + require.NoError(t, err, "Regular user should be able to login peers") +} diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go index 50e36a880..cb135f4ac 100644 --- a/management/server/peers/manager.go +++ b/management/server/peers/manager.go @@ -18,6 +18,7 @@ type Manager interface { GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) GetPeerAccountID(ctx context.Context, peerID string) (string, error) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) } type managerImpl struct { @@ -61,3 +62,7 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) } + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} diff --git a/management/server/peers/manager_mock.go b/management/server/peers/manager_mock.go index b247a1752..994f8346b 100644 --- a/management/server/peers/manager_mock.go +++ b/management/server/peers/manager_mock.go @@ -79,3 +79,18 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go mr.mock.ctrl.T.Helper() return 
mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID) } + +// GetPeersByGroupIDs mocks base method. +func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs) + ret0, _ := ret[0].([]*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs. +func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) +} diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 0ab244243..891fa59bb 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -54,10 +54,14 @@ func (m *managerImpl) ValidateUserPermissions( return false, status.NewUserNotFoundError(userID) } - if user.IsBlocked() { + if user.IsBlocked() && !user.PendingApproval { return false, status.NewUserBlockedError() } + if user.IsBlocked() && user.PendingApproval { + return false, status.NewUserPendingApprovalError() + } + if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil { return false, err } diff --git a/management/server/policy.go b/management/server/policy.go index d5c66e9f8..3adee6397 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -32,9 +32,6 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) { - unlock := 
am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - operation := operations.Create if !create { operation = operations.Update @@ -61,17 +58,17 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { - return err - } - saveFunc := transaction.CreatePolicy if isUpdate { action = activity.PolicyUpdated saveFunc = transaction.SavePolicy } - return saveFunc(ctx, policy) + if err = saveFunc(ctx, policy); err != nil { + return err + } + + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return nil, err @@ -88,9 +85,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -113,11 +107,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil { return err } - return transaction.DeletePolicy(ctx, accountID, policyID) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -173,10 +167,22 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a // validatePolicy validates the policy and its rules. 
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return err } + + // TODO: Refactor to support multiple rules per policy + existingRuleIDs := make(map[string]bool) + for _, rule := range existingPolicy.Rules { + existingRuleIDs[rule.ID] = true + } + + for _, rule := range policy.Rules { + if rule.ID != "" && !existingRuleIDs[rule.ID] { + return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + } + } } else { policy.ID = xid.New().String() policy.AccountID = accountID diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9414b8065..943f2a970 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -32,9 +32,6 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID // SavePostureChecks saves a posture check. 
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - operation := operations.Create if !create { operation = operations.Update @@ -62,15 +59,19 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { - return err - } - action = activity.PostureCheckUpdated } postureChecks.AccountID = accountID - return transaction.SavePostureChecks(ctx, postureChecks) + if err = transaction.SavePostureChecks(ctx, postureChecks); err != nil { + return err + } + + if isUpdate { + return transaction.IncrementNetworkSerial(ctx, accountID) + } + + return nil }) if err != nil { return nil, err @@ -87,9 +88,6 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. 
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) if err != nil { return status.NewPermissionValidationError(err) @@ -110,11 +108,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.DeletePostureChecks(ctx, accountID, postureChecksID); err != nil { return err } - return transaction.DeletePostureChecks(ctx, accountID, postureChecksID) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err diff --git a/management/server/route.go b/management/server/route.go index b853d9cd6..4510426bb 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -134,10 +134,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) 
{ allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -170,6 +167,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri Enabled: enabled, Groups: groups, AccessControlGroups: accessControlGroupIDs, + SkipAutoApply: skipAutoApply, } if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil { @@ -181,11 +179,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.SaveRoute(ctx, newRoute); err != nil { return err } - return transaction.SaveRoute(ctx, newRoute) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return nil, err @@ -202,9 +200,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri // SaveRoute saves route func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) if err != nil { return status.NewPermissionValidationError(err) @@ -238,11 +233,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } routeToSave.AccountID = accountID - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.SaveRoute(ctx, routeToSave); err != nil { return err } - return transaction.SaveRoute(ctx, routeToSave) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err @@ -259,9 +254,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI // DeleteRoute deletes route with routeID func (am 
*DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -284,11 +276,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return err } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil { return err } - return transaction.DeleteRoute(ctx, accountID, string(routeID)) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return fmt.Errorf("failed to delete route %s: %w", routeID, err) @@ -382,15 +374,16 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID func toProtocolRoute(route *route.Route) *proto.Route { return &proto.Route{ - ID: string(route.ID), - NetID: string(route.NetID), - Network: route.Network.String(), - Domains: route.Domains.ToPunycodeList(), - NetworkType: int64(route.NetworkType), - Peer: route.Peer, - Metric: int64(route.Metric), - Masquerade: route.Masquerade, - KeepRoute: route.KeepRoute, + ID: string(route.ID), + NetID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), + NetworkType: int64(route.NetworkType), + Peer: route.Peer, + Metric: int64(route.Metric), + Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, + SkipAutoApply: route.SkipAutoApply, } } diff --git a/management/server/route_test.go b/management/server/route_test.go index 6c61fdf9c..aeeeb736b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -69,6 +69,7 @@ func TestCreateRoute(t *testing.T) { enabled bool groups []string accessControlGroups []string + skipAutoApply bool } 
testCases := []struct { @@ -444,13 +445,13 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false, true) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false, true) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, 
testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute, testCase.inputArgs.skipAutoApply) testCase.errFunc(t, err) @@ -1084,7 +1085,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1176,7 +1177,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply) require.NoError(t, err) 
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -2004,7 +2005,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, - route.Groups, []string{}, true, userID, route.KeepRoute, + route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply, ) require.NoError(t, err) @@ -2040,7 +2041,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, - route.Groups, []string{}, true, userID, route.KeepRoute, + route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply, ) require.NoError(t, err) @@ -2076,7 +2077,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { newRoute, err := manager.CreateRoute( context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, - baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, + baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, !baseRoute.SkipAutoApply, ) require.NoError(t, err) baseRoute = *newRoute @@ -2142,7 +2143,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, - newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply, ) require.NoError(t, err) @@ -2182,7 +2183,7 @@ 
func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, - newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply, ) require.NoError(t, err) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 71915b4a2..8d0509871 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -55,8 +55,6 @@ type SetupKeyUpdateOperation struct { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create) if err != nil { @@ -107,9 +105,6 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index f27eddb2f..a820d99a9 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -52,7 +52,6 @@ const ( // SqlStore represents an 
account storage backed by a Sql DB persisted to disk type SqlStore struct { db *gorm.DB - resourceLocks sync.Map globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int @@ -219,44 +218,6 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { return unlock } -// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock -func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID) - - startWait := time.Now() - value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) - mtx := value.(*sync.RWMutex) - mtx.Lock() - log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait)) - startHold := time.Now() - - unlock = func() { - mtx.Unlock() - log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold)) - } - - return unlock -} - -// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock -func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID) - - startWait := time.Now() - value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) - mtx := value.(*sync.RWMutex) - mtx.RLock() - log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait)) - startHold := time.Now() - - unlock = func() { - mtx.RUnlock() - log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold)) - } - - return unlock -} - // Deprecated: Full account operations are no longer supported func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error { start := time.Now() @@ -1028,7 +989,7 @@ func (s *SqlStore) 
GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { var account types.Account - result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account) + result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account) if result.Error != nil { return "", status.NewGetAccountFromStoreError(result.Error) } @@ -1513,7 +1474,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI PeerID: peerID, } - err := s.db.WithContext(ctx).Clauses(clause.OnConflict{ + err := s.db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, DoNothing: true, }).Create(peer).Error @@ -1528,7 +1489,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI // RemovePeerFromGroup removes a peer from a group func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error { - err := s.db.WithContext(ctx). + err := s.db. Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error if err != nil { @@ -1541,7 +1502,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group // RemovePeerFromAllGroups removes a peer from all groups func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { - err := s.db.WithContext(ctx). + err := s.db. 
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error if err != nil { @@ -2129,7 +2090,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } @@ -2820,7 +2781,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -2961,3 +2922,22 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i } return nil } + +func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) { + if len(groupIDs) == 0 { + return []*nbpeer.Peer{}, nil + } + + var peers []*nbpeer.Peer + peerIDsSubquery := s.db.Model(&types.GroupPeer{}). + Select("DISTINCT peer_id"). + Where("account_id = ? 
AND group_id IN ?", accountID, groupIDs) + + result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by group IDs") + } + + return peers, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 935b0a595..d40c4664c 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3607,3 +3607,113 @@ func intToIPv4(n uint32) net.IP { binary.BigEndian.PutUint32(ip, n) return ip } + +func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group1ID := "test-group-1" + group2ID := "test-group-2" + emptyGroupID := "empty-group" + + peer1 := "cfefqs706sqkneg59g4g" + peer2 := "cfeg6sf06sqkneg59g50" + + tests := []struct { + name string + groupIDs []string + expectedPeers []string + expectedCount int + }{ + { + name: "retrieve peers from single group with multiple peers", + groupIDs: []string{group1ID}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from single group with one peer", + groupIDs: []string{group2ID}, + expectedPeers: []string{peer1}, + expectedCount: 1, + }, + { + name: "retrieve peers from multiple groups (with overlap)", + groupIDs: []string{group1ID, group2ID}, + expectedPeers: []string{peer1, peer2}, // should deduplicate + expectedCount: 2, + }, + { + name: "retrieve peers from existing 'All' group", + groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from empty group", + groupIDs: []string{emptyGroupID}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "retrieve peers from non-existing group", + groupIDs: []string{"non-existing-group"}, + 
expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "mix of existing and non-existing groups", + groupIDs: []string{group1ID, "non-existing-group"}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + + groups := []*types.Group{ + { + ID: group1ID, + AccountID: accountID, + }, + { + ID: group2ID, + AccountID: accountID, + }, + } + require.NoError(t, store.CreateGroups(ctx, accountID, groups)) + + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID)) + + peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + + if tt.expectedCount > 0 { + actualPeerIDs := make([]string, len(peers)) + for i, peer := range peers { + actualPeerIDs[i] = peer.ID + } + assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs) + + // Verify all returned peers belong to the correct account + for _, peer := range peers { + assert.Equal(t, accountID, peer.AccountID) + } + } + }) + } +} diff --git a/management/server/store/store.go b/management/server/store/store.go index d8566e086..31d027c36 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -136,6 +136,7 @@ type Store interface { GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, 
error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) @@ -168,10 +169,6 @@ type Store interface { GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error - // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock - AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() - // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock - AcquireReadLockByUID(ctx context.Context, uniqueID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index ac6ff2ea8..d4301802f 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -4,20 +4,28 @@ import ( "context" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) +const AccountIDLabel = "account_id" +const HighLatencyThreshold = time.Second * 7 + // GRPCMetrics are gRPC server metrics type GRPCMetrics struct { - meter metric.Meter - syncRequestsCounter metric.Int64Counter - loginRequestsCounter metric.Int64Counter - getKeyRequestsCounter metric.Int64Counter - activeStreamsGauge metric.Int64ObservableGauge - syncRequestDuration metric.Int64Histogram - 
loginRequestDuration metric.Int64Histogram - channelQueueLength metric.Int64Histogram - ctx context.Context + meter metric.Meter + syncRequestsCounter metric.Int64Counter + syncRequestsBlockedCounter metric.Int64Counter + syncRequestHighLatencyCounter metric.Int64Counter + loginRequestsCounter metric.Int64Counter + loginRequestsBlockedCounter metric.Int64Counter + loginRequestHighLatencyCounter metric.Int64Counter + getKeyRequestsCounter metric.Int64Counter + activeStreamsGauge metric.Int64ObservableGauge + syncRequestDuration metric.Int64Histogram + loginRequestDuration metric.Int64Histogram + channelQueueLength metric.Int64Histogram + ctx context.Context } // NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server @@ -30,6 +38,22 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + syncRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.sync.request.blocked.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of sync gRPC requests from blocked peers"), + ) + if err != nil { + return nil, err + } + + syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"), + ) + if err != nil { + return nil, err + } + loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -38,6 +62,22 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + loginRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.login.request.blocked.counter", + 
metric.WithUnit("1"), + metric.WithDescription("Number of login gRPC requests from blocked peers"), + ) + if err != nil { + return nil, err + } + + loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"), + ) + if err != nil { + return nil, err + } + getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of key gRPC requests from the peers to get the server's public WireGuard key"), @@ -83,15 +123,19 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro } return &GRPCMetrics{ - meter: meter, - syncRequestsCounter: syncRequestsCounter, - loginRequestsCounter: loginRequestsCounter, - getKeyRequestsCounter: getKeyRequestsCounter, - activeStreamsGauge: activeStreamsGauge, - syncRequestDuration: syncRequestDuration, - loginRequestDuration: loginRequestDuration, - channelQueueLength: channelQueue, - ctx: ctx, + meter: meter, + syncRequestsCounter: syncRequestsCounter, + syncRequestsBlockedCounter: syncRequestsBlockedCounter, + syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, + loginRequestsCounter: loginRequestsCounter, + loginRequestsBlockedCounter: loginRequestsBlockedCounter, + loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, + getKeyRequestsCounter: getKeyRequestsCounter, + activeStreamsGauge: activeStreamsGauge, + syncRequestDuration: syncRequestDuration, + loginRequestDuration: loginRequestDuration, + channelQueueLength: channelQueue, + ctx: ctx, }, err } @@ -100,6 +144,11 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequest() { grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1) } +// CountSyncRequestBlocked counts the number of gRPC sync requests from blocked 
peers +func (grpcMetrics *GRPCMetrics) CountSyncRequestBlocked() { + grpcMetrics.syncRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) +} + // CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() { grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1) @@ -110,14 +159,25 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() { grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1) } +// CountLoginRequestBlocked counts the number of gRPC login requests from blocked peers +func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() { + grpcMetrics.loginRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) +} + // CountLoginRequestDuration counts the duration of the login gRPC requests -func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration) { +func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) { grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + if duration > HighLatencyThreshold { + grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + } } // CountSyncRequestDuration counts the duration of the sync gRPC requests -func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration) { +func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + if duration > HighLatencyThreshold { + grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + } } // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. 
diff --git a/management/server/types/account.go b/management/server/types/account.go index 9ac2568a0..a69d3bb08 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -300,9 +300,12 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) } dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) @@ -1651,3 +1654,24 @@ func peerSupportsPortRanges(peerVer string) bool { meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) return err == nil && meetMinVer } + +// filterZoneRecordsForPeers filters DNS records to only include peers to connect. +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { + filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) + peerIPs := make(map[string]struct{}) + + // Add peer's own IP to include its own DNS records + peerIPs[peer.IP.String()] = struct{}{} + + for _, peerToConnect := range peersToConnect { + peerIPs[peerToConnect.IP.String()] = struct{}{} + } + + for _, record := range customZone.Records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f8ab1d627..cd221b590 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -2,14 +2,17 @@ package types import ( "context" + "fmt" "net" "net/netip" "slices" "testing" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns 
"github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -835,3 +838,109 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") } + +func Test_FilterZoneRecordsForPeers(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + customZone nbdns.CustomZone + peersToConnect []*nbpeer.Peer + expectedRecords []nbdns.SimpleRecord + }{ + { + name: "empty peers to connect", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + { + name: "multiple peers multiple records match", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for i := 1; i <= 100; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + peersToConnect: func() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + peers = append(peers, 
&nbpeer.Peer{ + ID: fmt.Sprintf("peer%d", i), + IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + }) + } + return peers + }(), + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + { + name: "peers with multiple DNS labels", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, + {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + 
expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + assert.Equal(t, len(tt.expectedRecords), len(result)) + assert.ElementsMatch(t, tt.expectedRecords, result) + }) + } +} diff --git a/management/server/types/network.go b/management/server/types/network.go index f072a4294..ffc019565 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -12,11 +12,11 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" ) const ( diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 56c33da3b..b4afb2f5e 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -83,6 +83,9 @@ type ExtraSettings struct { // PeerApprovalEnabled enables or disables the 
need for peers bo be approved by an administrator PeerApprovalEnabled bool + // UserApprovalRequired enables or disables the need for users joining via domain matching to be approved by an administrator + UserApprovalRequired bool + // IntegratedValidator is the string enum for the integrated validator type IntegratedValidator string // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations @@ -99,6 +102,7 @@ type ExtraSettings struct { func (e *ExtraSettings) Copy() *ExtraSettings { return &ExtraSettings{ PeerApprovalEnabled: e.PeerApprovalEnabled, + UserApprovalRequired: e.UserApprovalRequired, IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups), IntegratedValidator: e.IntegratedValidator, FlowEnabled: e.FlowEnabled, diff --git a/management/server/types/user.go b/management/server/types/user.go index 783fe14da..beb3586df 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -64,6 +64,7 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` + PendingApproval bool `json:"pending_approval"` IntegrationReference integration_reference.IntegrationReference `json:"-"` } @@ -84,6 +85,8 @@ type User struct { PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` // Blocked indicates whether the user is blocked. Blocked users can't use the system. 
Blocked bool + // PendingApproval indicates whether the user requires approval before being activated + PendingApproval bool // LastLogin is the last time the user logged in to IdP LastLogin *time.Time // CreatedAt records the time the user was created @@ -141,16 +144,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { if userData == nil { return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + PendingApproval: u.PendingApproval, }, nil } if userData.ID != u.Id { @@ -163,16 +167,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { } return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + PendingApproval: u.PendingApproval, }, nil } @@ -194,6 +199,7 @@ func (u *User) Copy() *User { ServiceUserName: u.ServiceUserName, PATs: pats, Blocked: u.Blocked, + PendingApproval: u.PendingApproval, LastLogin: u.LastLogin, CreatedAt: u.CreatedAt, Issued: u.Issued, diff --git a/management/server/user.go b/management/server/user.go index ba1835f22..d40d33c6a 100644 --- a/management/server/user.go +++ b/management/server/user.go 
@@ -26,9 +26,6 @@ import ( // createServiceUser creates a new service user under the given account. func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -76,9 +73,6 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user // inviteNewUser Invites a USer to a given account and creates reference in datastore func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if am.idpManager == nil { return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } @@ -227,9 +221,6 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.Errorf(status.InvalidArgument, "self deletion is not allowed") } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return err @@ -285,9 +276,6 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. 
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if am.idpManager == nil { return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } @@ -328,9 +316,6 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin // CreatePAT creates a new PAT for the given user func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if tokenName == "" { return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } @@ -379,9 +364,6 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -481,9 +463,6 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. 
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists) if err != nil { return nil, err @@ -540,33 +519,46 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUser = result } - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - for _, update := range updates { - if update == nil { - return status.Errorf(status.InvalidArgument, "provided user update is nil") - } + var globalErr error + for _, update := range updates { + if update == nil { + return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } - usersToSave = append(usersToSave, updatedUser) - addUserEvents = append(addUserEvents, userEvents...) - peersToExpire = append(peersToExpire, userPeersToExpire...) if userHadPeers { updateAccountPeers = true } + + err = transaction.SaveUser(ctx, updatedUser) + if err != nil { + return fmt.Errorf("failed to save updated user %s: %w", update.Id, err) + } + + usersToSave = append(usersToSave, updatedUser) + addUserEvents = append(addUserEvents, userEvents...) + peersToExpire = append(peersToExpire, userPeersToExpire...) 
+ + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to save user %s: %s", update.Id, err) + if len(updates) == 1 { + return nil, err + } + globalErr = errors.Join(globalErr, err) + // continue when updating multiple users } - return transaction.SaveUsers(ctx, usersToSave) - }) - if err != nil { - return nil, err } - var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates)) + var updatedUsersInfo = make([]*types.UserInfo, 0, len(usersToSave)) userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID) if err != nil { @@ -599,7 +591,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, am.UpdateAccountPeers(ctx, accountID) } - return updatedUsersInfo, nil + return updatedUsersInfo, globalErr } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. @@ -664,7 +656,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } transferredOwnerRole = result - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, updatedUser.AccountID, update.Id) if err != nil { return false, nil, nil, nil, err } @@ -950,6 +942,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + if peer.UserID == "" { + // we do not want to expire peers that are added via setup key + continue + } + if peer.Status.LoginExpired { continue } @@ -968,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service + log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.BufferUpdateAccountPeers(ctx, accountID) } @@ 
-1215,3 +1213,77 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut return userWithPermissions, nil } + +// ApproveUser approves a user that is pending approval +func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.NewUserNotFoundError(targetUserID) + } + + if !user.PendingApproval { + return nil, status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) + } + + user.Blocked = false + user.PendingApproval = false + + err = am.Store.SaveUser(ctx, user) + if err != nil { + return nil, err + } + + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserApproved, nil) + + userInfo, err := am.getUserInfo(ctx, user, accountID) + if err != nil { + return nil, err + } + + return userInfo, nil +} + +// RejectUser rejects a user that is pending approval by deleting them +func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return 
status.NewUserNotFoundError(targetUserID) + } + + if !user.PendingApproval { + return status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) + } + + err = am.DeleteUser(ctx, accountID, initiatorUserID, targetUserID) + if err != nil { + return err + } + + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserRejected, nil) + + return nil +} diff --git a/management/server/user_test.go b/management/server/user_test.go index 8ab0c1565..9638559f9 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1746,3 +1746,117 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { return permissions } + +func TestApproveUser(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account with admin and pending approval user + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create admin user + adminUser := types.NewAdminUser("admin-user") + adminUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), adminUser) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Test successful approval + approvedUser, err := manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.NoError(t, err) + assert.False(t, approvedUser.IsBlocked) + assert.False(t, approvedUser.PendingApproval) + + // Verify user is updated in store + updatedUser, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) + require.NoError(t, err) + assert.False(t, 
updatedUser.Blocked) + assert.False(t, updatedUser.PendingApproval) + + // Test approval of non-pending user should fail + _, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not pending approval") + + // Test approval by non-admin should fail + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2.AccountID = account.Id + pendingUser2.Blocked = true + pendingUser2.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser2) + require.NoError(t, err) + + _, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) + require.Error(t, err) +} + +func TestRejectUser(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account with admin and pending approval user + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create admin user + adminUser := types.NewAdminUser("admin-user") + adminUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), adminUser) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Test successful rejection + err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.NoError(t, err) + + // Verify user is deleted from store + _, err = 
manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) + require.Error(t, err) + + // Test rejection of non-pending user should fail + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not pending approval") + + // Test rejection by non-admin should fail + pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2.AccountID = account.Id + pendingUser2.Blocked = true + pendingUser2.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser2) + require.NoError(t, err) + + err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) + require.Error(t, err) +} diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 332127660..12219e29b 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -73,7 +73,12 @@ func (l *Listener) Shutdown(ctx context.Context) error { func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { connRemoteAddr := remoteAddr(r) - wsConn, err := websocket.Accept(w, r, nil) + + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + } + + wsConn, err := websocket.Accept(w, r, acceptOptions) if err != nil { log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err) return diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index 6b8a6f701..4dfea6da1 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -13,10 +13,11 @@ import ( "github.com/pion/logging" "github.com/pion/turn/v3" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/relay/server" 
"github.com/netbirdio/netbird/shared/relay/auth/allow" "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client" - "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/util" ) @@ -100,7 +101,7 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i), iface.DefaultMTU) err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -110,7 +111,7 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i), iface.DefaultMTU) err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go index aa0fc662a..e6924061f 100644 --- a/relay/testec2/relay.go +++ b/relay/testec2/relay.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client" ) @@ -70,7 +71,7 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { ctx := context.Background() clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i), iface.DefaultMTU) if err := c.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -156,7 +157,7 
@@ func runReader(conn net.Conn) time.Duration { func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i), iface.DefaultMTU) err := c.Connect(context.Background()) if err != nil { log.Fatalf("failed to connect to server: %s", err) diff --git a/release_files/install.sh b/release_files/install.sh index 856d332cb..5d5349ec4 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -130,36 +130,6 @@ repo_gpgcheck=1 EOF } -install_aur_package() { - INSTALL_PKGS="git base-devel go" - REMOVE_PKGS="" - - # Check if dependencies are installed - for PKG in $INSTALL_PKGS; do - if ! pacman -Q "$PKG" > /dev/null 2>&1; then - # Install missing package(s) - ${SUDO} pacman -S "$PKG" --noconfirm - - # Add installed package for clean up later - REMOVE_PKGS="$REMOVE_PKGS $PKG" - fi - done - - # Build package from AUR - cd /tmp && git clone https://aur.archlinux.org/netbird.git - cd netbird && makepkg -sri --noconfirm - - if ! $SKIP_UI_APP; then - cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git - cd netbird-ui && makepkg -sri --noconfirm - fi - - if [ -n "$REMOVE_PKGS" ]; then - # Clean up the installed packages - ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm - fi -} - prepare_tun_module() { # Create the necessary file structure for /dev/net/tun if [ ! -c /dev/net/tun ]; then @@ -276,12 +246,9 @@ install_netbird() { if ! 
$SKIP_UI_APP; then ${SUDO} rpm-ostree -y install netbird-ui fi - ;; - pacman) - ${SUDO} pacman -Syy - install_aur_package - # in-line with the docs at https://wiki.archlinux.org/title/Netbird - ${SUDO} systemctl enable --now netbird@main.service + # ensure the service is started after install + ${SUDO} netbird service install || true + ${SUDO} netbird service start || true ;; pkg) # Check if the package is already installed @@ -458,11 +425,7 @@ if type uname >/dev/null 2>&1; then elif [ -x "$(command -v yum)" ]; then PACKAGE_MANAGER="yum" echo "The installation will be performed using yum package manager" - elif [ -x "$(command -v pacman)" ]; then - PACKAGE_MANAGER="pacman" - echo "The installation will be performed using pacman package manager" fi - else echo "Unable to determine OS type from /etc/os-release" exit 1 diff --git a/route/route.go b/route/route.go index 604f8c60f..08a2d37dc 100644 --- a/route/route.go +++ b/route/route.go @@ -107,6 +107,8 @@ type Route struct { Enabled bool Groups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"` + // SkipAutoApply indicates if this exit node route (0.0.0.0/0) should skip auto-application for client routing + SkipAutoApply bool } // EventMeta returns activity event meta related to the route @@ -136,6 +138,7 @@ func (r *Route) Copy() *Route { Enabled: r.Enabled, Groups: slices.Clone(r.Groups), AccessControlGroups: slices.Clone(r.AccessControlGroups), + SkipAutoApply: r.SkipAutoApply, } return route } @@ -162,7 +165,8 @@ func (r *Route) Equal(other *Route) bool { other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && slices.Equal(r.PeerGroups, other.PeerGroups) && - slices.Equal(r.AccessControlGroups, other.AccessControlGroups) + slices.Equal(r.AccessControlGroups, other.AccessControlGroups) && + other.SkipAutoApply == r.SkipAutoApply } // IsDynamic returns if the route is dynamic, i.e. 
has domains diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index e38ce9b2f..5736a16e1 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -9,34 +9,30 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - - "github.com/netbirdio/management-integrations/integrations" - - "github.com/netbirdio/netbird/encryption" - mgmt "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/mock_server" - mgmtProto "github.com/netbirdio/netbird/shared/management/proto" - + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmt "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + 
"github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -73,13 +69,31 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) jobManager := mgmt.NewJobManager(nil, store) eventStore := &activity.InMemoryEventStore{} - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + permissionsManagerMock := permissions.NewMockManager(ctrl) + permissionsManagerMock. + EXPECT(). + ValidateUserPermissions( + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any(), + ). + Return(true, nil). + AnyTimes() + + peersManger := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager. EXPECT(). 
@@ -110,6 +124,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { AnyTimes() accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + if err != nil { t.Fatal(err) } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index f5759ef21..b3fd28e9c 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -18,11 +18,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -53,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/management/client/rest/client_test.go b/shared/management/client/rest/client_test.go index 56c859652..54a0290d0 100644 --- a/shared/management/client/rest/client_test.go +++ b/shared/management/client/rest/client_test.go @@ -8,8 +8,8 @@ import ( "net/http/httptest" "testing" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/client/rest" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" ) func withMockClient(callback func(*rest.Client, *http.ServeMux)) { @@ -26,7 +26,7 @@ func ptr[T any, PT *T](x T) PT { func 
withBlackBoxServer(t *testing.T, callback func(*rest.Client)) { t.Helper() - handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) + handler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) server := httptest.NewServer(handler) defer server.Close() c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp") diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 86082e606..942156ad2 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -278,6 +278,10 @@ components: description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. type: boolean example: true + user_approval_required: + description: Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. + type: boolean + example: false network_traffic_logs_enabled: description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored. type: boolean @@ -294,6 +298,7 @@ components: example: true required: - peer_approval_enabled + - user_approval_required - network_traffic_logs_enabled - network_traffic_logs_groups - network_traffic_packet_counter_enabled @@ -355,6 +360,10 @@ components: description: Is true if this user is blocked. Blocked users can't use the system type: boolean example: false + pending_approval: + description: Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. 
+ type: boolean + example: false issued: description: How user was issued by API or Integration type: string @@ -369,6 +378,7 @@ components: - auto_groups - status - is_blocked + - pending_approval UserPermissions: type: object properties: @@ -1462,6 +1472,10 @@ components: items: type: string example: "chacbco6lnnbn6cg5s91" + skip_auto_apply: + description: Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing + type: boolean + example: false required: - id - description @@ -2764,6 +2778,63 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/{userId}/approve: + post: + summary: Approve user + description: Approve a user that is pending approval + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The unique identifier of a user + responses: + '200': + description: Returns the approved user + content: + application/json: + schema: + "$ref": "#/components/schemas/User" + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/{userId}/reject: + delete: + summary: Reject user + description: Reject a user that is pending approval by removing them from the account + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The unique identifier of a user + responses: + '200': + description: User rejected successfully + content: {} + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": 
"#/components/responses/internal_error" /api/users/current: get: summary: Retrieve current user diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 09c1ecb0f..883ce4928 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -284,6 +284,9 @@ type AccountExtraSettings struct { // PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. PeerApprovalEnabled bool `json:"peer_approval_enabled"` + + // UserApprovalRequired Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. + UserApprovalRequired bool `json:"user_approval_required"` } // AccountOnboarding defines model for AccountOnboarding. @@ -1619,6 +1622,9 @@ type Route struct { // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` PeerGroups *[]string `json:"peer_groups,omitempty"` + + // SkipAutoApply Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing + SkipAutoApply *bool `json:"skip_auto_apply,omitempty"` } // RouteRequest defines model for RouteRequest. @@ -1658,6 +1664,9 @@ type RouteRequest struct { // PeerGroups Peers Group Identifier associated with route. 
This property can not be set together with `peer` PeerGroups *[]string `json:"peer_groups,omitempty"` + + // SkipAutoApply Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing + SkipAutoApply *bool `json:"skip_auto_apply,omitempty"` } // RulePortRange Policy rule affected ports range @@ -1846,8 +1855,11 @@ type User struct { LastLogin *time.Time `json:"last_login,omitempty"` // Name User's name from idp provider - Name string `json:"name"` - Permissions *UserPermissions `json:"permissions,omitempty"` + Name string `json:"name"` + + // PendingApproval Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. + PendingApproval bool `json:"pending_approval"` + Permissions *UserPermissions `json:"permissions,omitempty"` // Role User's NetBird account role Role string `json:"role"` diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 5ee9565df..97e06f6a1 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v4.24.3 // source: management.proto package proto @@ -1982,6 +1982,7 @@ type PeerConfig struct { Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,6,opt,name=LazyConnectionEnabled,proto3" json:"LazyConnectionEnabled,omitempty"` + Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"` } func (x *PeerConfig) Reset() { @@ -2058,6 +2059,13 @@ func (x *PeerConfig) GetLazyConnectionEnabled() bool { return false } +func (x *PeerConfig) GetMtu() int32 { + if x != nil { + return x.Mtu + } + return 0 +} + // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -2693,15 +2701,16 @@ type Route struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=Network,proto3" json:"Network,omitempty"` - NetworkType int64 `protobuf:"varint,3,opt,name=NetworkType,proto3" json:"NetworkType,omitempty"` - Peer string `protobuf:"bytes,4,opt,name=Peer,proto3" json:"Peer,omitempty"` - Metric int64 `protobuf:"varint,5,opt,name=Metric,proto3" json:"Metric,omitempty"` - Masquerade bool `protobuf:"varint,6,opt,name=Masquerade,proto3" json:"Masquerade,omitempty"` - NetID string `protobuf:"bytes,7,opt,name=NetID,proto3" json:"NetID,omitempty"` - Domains []string `protobuf:"bytes,8,rep,name=Domains,proto3" json:"Domains,omitempty"` - KeepRoute bool `protobuf:"varint,9,opt,name=keepRoute,proto3" json:"keepRoute,omitempty"` + ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` + Network string 
`protobuf:"bytes,2,opt,name=Network,proto3" json:"Network,omitempty"` + NetworkType int64 `protobuf:"varint,3,opt,name=NetworkType,proto3" json:"NetworkType,omitempty"` + Peer string `protobuf:"bytes,4,opt,name=Peer,proto3" json:"Peer,omitempty"` + Metric int64 `protobuf:"varint,5,opt,name=Metric,proto3" json:"Metric,omitempty"` + Masquerade bool `protobuf:"varint,6,opt,name=Masquerade,proto3" json:"Masquerade,omitempty"` + NetID string `protobuf:"bytes,7,opt,name=NetID,proto3" json:"NetID,omitempty"` + Domains []string `protobuf:"bytes,8,rep,name=Domains,proto3" json:"Domains,omitempty"` + KeepRoute bool `protobuf:"varint,9,opt,name=keepRoute,proto3" json:"keepRoute,omitempty"` + SkipAutoApply bool `protobuf:"varint,10,opt,name=skipAutoApply,proto3" json:"skipAutoApply,omitempty"` } func (x *Route) Reset() { @@ -2799,6 +2808,13 @@ func (x *Route) GetKeepRoute() bool { return false } +func (x *Route) GetSkipAutoApply() bool { + if x != nil { + return x.SkipAutoApply + } + return false +} + // DNSConfig represents a dns.Update type DNSConfig struct { state protoimpl.MessageState diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 4aaefb5aa..ae5a1b29d 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -303,6 +303,8 @@ message PeerConfig { bool RoutingPeerDnsResolutionEnabled = 5; bool LazyConnectionEnabled = 6; + + int32 mtu = 7; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections @@ -439,6 +441,7 @@ message Route { string NetID = 7; repeated string Domains = 8; bool keepRoute = 9; + bool skipAutoApply = 10; } // DNSConfig represents a dns.Update diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 7660174d6..1e914babb 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -42,7 +42,10 @@ const ( // Type is 
a type of the Error type Type int32 -var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found") +var ( + ErrExtraSettingsNotFound = errors.New("extra settings not found") + ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in") +) // Error is an internal error type Error struct { @@ -110,6 +113,11 @@ func NewUserBlockedError() error { return Errorf(PermissionDenied, "user is blocked") } +// NewUserPendingApprovalError creates a new Error with PermissionDenied type for a blocked user pending approval +func NewUserPendingApprovalError() error { + return Errorf(PermissionDenied, "user is pending approval") +} + // NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer func NewPeerNotRegisteredError() error { return Errorf(Unauthenticated, "peer is not registered") diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index 37c9debc2..5dabc5742 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client/dialer" "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" @@ -143,10 +144,12 @@ type Client struct { listenerMutex sync.Mutex stateSubscription *PeersStateSubscription + + mtu uint16 } // NewClient creates a new client for the relay server. 
The client is not connected to the server until the Connect -func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { +func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { hashedID := messages.HashID(peerID) relayLog := log.WithFields(log.Fields{"relay": serverURL}) @@ -155,6 +158,7 @@ func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) connectionURL: serverURL, authTokenStore: authTokenStore, hashedID: hashedID, + mtu: mtu, bufPool: &sync.Pool{ New: func() any { buf := make([]byte, bufferSize) @@ -292,7 +296,16 @@ func (c *Client) Close() error { } func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { - rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{}) + // Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues + var dialers []dialer.DialeFn + if c.mtu > 0 && c.mtu > iface.DefaultMTU { + c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) + dialers = []dialer.DialeFn{ws.Dialer{}} + } else { + dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} + } + + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) 
conn, err := rd.Dial() if err != nil { return nil, err diff --git a/shared/relay/client/client_test.go b/shared/relay/client/client_test.go index c7c5fbf2b..8fe5f04f4 100644 --- a/shared/relay/client/client_test.go +++ b/shared/relay/client/client_test.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/shared/relay/auth/allow" "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/util" @@ -63,7 +64,7 @@ func TestClient(t *testing.T) { t.Fatalf("failed to start server: %s", err) } t.Log("alice connecting to server") - clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -71,7 +72,7 @@ func TestClient(t *testing.T) { defer clientAlice.Close() t.Log("placeholder connecting to server") - clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder") + clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU) err = clientPlaceHolder.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -79,7 +80,7 @@ func TestClient(t *testing.T) { defer clientPlaceHolder.Close() t.Log("Bob connecting to server") - clientBob := NewClient(serverURL, hmacTokenStore, "bob") + clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -137,7 +138,7 @@ func TestRegistration(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { _ = srv.Shutdown(ctx) @@ -177,7 +178,7 @@ func 
TestRegistrationTimeout(t *testing.T) { _ = fakeTCPListener.Close() }(fakeTCPListener) - clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice") + clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err == nil { t.Errorf("failed to connect to server: %s", err) @@ -218,7 +219,7 @@ func TestEcho(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -230,7 +231,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := NewClient(serverURL, hmacTokenStore, idBob) + clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -308,7 +309,7 @@ func TestBindToUnavailabePeer(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -354,13 +355,13 @@ func TestBindReconnect(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - clientBob := NewClient(serverURL, hmacTokenStore, "bob") + clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -382,7 +383,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } 
- clientAlice = NewClient(serverURL, hmacTokenStore, "alice") + clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -455,13 +456,13 @@ func TestCloseConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - bob := NewClient(serverURL, hmacTokenStore, "bob") + bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) err = bob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -517,13 +518,13 @@ func TestCloseRelayConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - bob := NewClient(serverURL, hmacTokenStore, "bob") + bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) err = bob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice") + clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -571,7 +572,7 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(serverURL, hmacTokenStore, idAlice) + relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) if err = relayClient.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -627,7 +628,7 @@ func TestCloseByClient(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(serverURL, hmacTokenStore, idAlice) + relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) err = relayClient.Connect(ctx) if 
err != nil { log.Fatalf("failed to connect to server: %s", err) @@ -678,7 +679,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -690,7 +691,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { } }() - clientBob := NewClient(serverURL, hmacTokenStore, idBob) + clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index b496f6a9b..967e18d79 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" quictls "github.com/netbirdio/netbird/shared/relay/tls" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 109651f5d..ef6bd6b3c 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index f3428f255..6220e7f6b 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -63,20 +63,25 @@ type Manager struct { onDisconnectedListeners map[string]*list.List onReconnectedListenerFn func() listenerLock sync.Mutex + + mtu uint16 } // NewManager creates a new 
manager instance. // The serverURL address can be empty. In this case, the manager will not serve. -func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { +func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager { tokenStore := &relayAuth.TokenStore{} m := &Manager{ ctx: ctx, peerID: peerID, tokenStore: tokenStore, + mtu: mtu, serverPicker: &ServerPicker{ - TokenStore: tokenStore, - PeerID: peerID, + TokenStore: tokenStore, + PeerID: peerID, + MTU: mtu, + ConnectionTimeout: defaultConnectionTimeout, }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), @@ -253,7 +258,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string m.relayClients[serverAddress] = rt m.relayClientsMutex.Unlock() - relayClient := NewClient(serverAddress, m.tokenStore, m.peerID) + relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu) err := relayClient.Connect(m.ctx) if err != nil { rt.err = err diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index 674555ff4..f00b35707 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -8,14 +8,15 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" - "github.com/netbirdio/netbird/shared/relay/auth/allow" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/shared/relay/auth/allow" ) func TestEmptyURL(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mgr := NewManager(ctx, nil, "alice") + mgr := NewManager(ctx, nil, "alice", iface.DefaultMTU) err := mgr.Serve() if err == nil { t.Errorf("expected error, got nil") @@ -90,12 +91,12 @@ func TestForeignConn(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice") + clientAlice := 
NewManager(mCtx, toURL(lstCfg1), "alice", iface.DefaultMTU) if err := clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientBob := NewManager(mCtx, toURL(srvCfg2), "bob") + clientBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU) if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } @@ -197,12 +198,12 @@ func TestForeginConnClose(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob") + mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU) if err := mgrBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - mgr := NewManager(mCtx, toURL(srvCfg1), "alice") + mgr := NewManager(mCtx, toURL(srvCfg1), "alice", iface.DefaultMTU) err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -282,7 +283,7 @@ func TestForeignAutoClose(t *testing.T) { t.Log("connect to server 1.") mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgr := NewManager(mCtx, toURL(srvCfg1), idAlice) + mgr := NewManager(mCtx, toURL(srvCfg1), idAlice, iface.DefaultMTU) err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -353,13 +354,13 @@ func TestAutoReconnect(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientBob := NewManager(mCtx, toURL(srvCfg), "bob") + clientBob := NewManager(mCtx, toURL(srvCfg), "bob", iface.DefaultMTU) err = clientBob.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU) err = clientAlice.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -428,12 +429,12 @@ func TestNotifierDoubleAdd(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob") + clientBob := 
NewManager(mCtx, toURL(listenerCfg1), "bob", iface.DefaultMTU) if err = clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice") + clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice", iface.DefaultMTU) if err = clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } diff --git a/shared/relay/client/picker.go b/shared/relay/client/picker.go index 1cad466ba..39d0ba072 100644 --- a/shared/relay/client/picker.go +++ b/shared/relay/client/picker.go @@ -13,11 +13,8 @@ import ( ) const ( - maxConcurrentServers = 7 -) - -var ( - connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 + defaultConnectionTimeout = 30 * time.Second ) type connResult struct { @@ -27,13 +24,15 @@ type connResult struct { } type ServerPicker struct { - TokenStore *auth.TokenStore - ServerURLs atomic.Value - PeerID string + TokenStore *auth.TokenStore + ServerURLs atomic.Value + PeerID string + MTU uint16 + ConnectionTimeout time.Duration } func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { - ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout) defer cancel() totalServers := len(sp.ServerURLs.Load().([]string)) @@ -70,7 +69,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { log.Infof("try to connecting to relay server: %s", url) - relayClient := NewClient(url, sp.TokenStore, sp.PeerID) + relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU) err := relayClient.Connect(ctx) resultChan <- connResult{ RelayClient: relayClient, diff --git a/shared/relay/client/picker_test.go b/shared/relay/client/picker_test.go index 28167c5ce..fb3fa7375 100644 --- a/shared/relay/client/picker_test.go +++ 
b/shared/relay/client/picker_test.go @@ -8,15 +8,15 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { - connectionTimeout = 5 * time.Second - + timeout := 5 * time.Second sp := ServerPicker{ - TokenStore: nil, - PeerID: "test", + TokenStore: nil, + PeerID: "test", + ConnectionTimeout: timeout, } sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) - ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) + ctx, cancel := context.WithTimeout(context.Background(), timeout+1) defer cancel() go func() { diff --git a/shared/relay/healthcheck/env.go b/shared/relay/healthcheck/env.go new file mode 100644 index 000000000..2b584c195 --- /dev/null +++ b/shared/relay/healthcheck/env.go @@ -0,0 +1,24 @@ +package healthcheck + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" +) + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. 
Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/shared/relay/healthcheck/env_test.go b/shared/relay/healthcheck/env_test.go new file mode 100644 index 000000000..2e14bb8bf --- /dev/null +++ b/shared/relay/healthcheck/env_test.go @@ -0,0 +1,36 @@ +package healthcheck + +import ( + "os" + "testing" +) + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/shared/relay/healthcheck/receiver.go b/shared/relay/healthcheck/receiver.go index b3503d5db..90f795bbe 100644 --- a/shared/relay/healthcheck/receiver.go +++ b/shared/relay/healthcheck/receiver.go @@ -7,10 +7,15 @@ import ( log "github.com/sirupsen/logrus" ) -var ( - heartbeatTimeout = healthCheckInterval + 10*time.Second +const ( + defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second ) +type ReceiverOptions struct { + HeartbeatTimeout time.Duration + AttemptThreshold int +} + // Receiver is a healthcheck receiver // It will listen for heartbeat and check if the heartbeat is not received in a certain time // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work @@ -27,6 +32,23 @@ type Receiver struct { // 
NewReceiver creates a new healthcheck receiver and start the timer in the background func NewReceiver(log *log.Entry) *Receiver { + opts := ReceiverOptions{ + HeartbeatTimeout: defaultHeartbeatTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewReceiverWithOpts(log, opts) +} + +func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver { + heartbeatTimeout := opts.HeartbeatTimeout + if heartbeatTimeout <= 0 { + heartbeatTimeout = defaultHeartbeatTimeout + } + attemptThreshold := opts.AttemptThreshold + if attemptThreshold <= 0 { + attemptThreshold = defaultAttemptThreshold + } + ctx, ctxCancel := context.WithCancel(context.Background()) r := &Receiver{ @@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver { ctx: ctx, ctxCancel: ctxCancel, heartbeat: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + attemptThreshold: attemptThreshold, } - go r.waitForHealthcheck() + go r.waitForHealthcheck(heartbeatTimeout) return r } @@ -55,7 +77,7 @@ func (r *Receiver) Stop() { r.ctxCancel() } -func (r *Receiver) waitForHealthcheck() { +func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) { ticker := time.NewTicker(heartbeatTimeout) defer ticker.Stop() defer r.ctxCancel() diff --git a/shared/relay/healthcheck/receiver_test.go b/shared/relay/healthcheck/receiver_test.go index 2794159f6..b20cc5124 100644 --- a/shared/relay/healthcheck/receiver_test.go +++ b/shared/relay/healthcheck/receiver_test.go @@ -2,31 +2,18 @@ package healthcheck import ( "context" - "fmt" - "os" - "sync" "testing" "time" log "github.com/sirupsen/logrus" ) -// Mutex to protect global variable access in tests -var testMutex sync.Mutex - func TestNewReceiver(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 5 * time.Second - testMutex.Unlock() - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := 
NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 5 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) { } func TestNewReceiverNotReceive(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 1 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 1 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) { } func TestNewReceiverAck(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 2 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 2 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() r.Heartbeat() @@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - testMutex.Lock() - originalInterval := healthCheckInterval - originalTimeout := heartbeatTimeout - healthCheckInterval = 1 * time.Second - heartbeatTimeout = healthCheckInterval + 500*time.Millisecond - testMutex.Unlock() + healthCheckInterval := 1 * time.Second - defer func() { - testMutex.Lock() - healthCheckInterval = originalInterval - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer 
os.Unsetenv(defaultAttemptThresholdEnv) + opts := ReceiverOptions{ + HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond, + AttemptThreshold: tc.threshold, + } - receiver := NewReceiver(log.WithField("test_name", tc.name)) + receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts) - testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval if tc.resetCounterOnce { receiver.Heartbeat() - t.Logf("reset counter once") } select { @@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { } t.Fatalf("should have timed out before %s", testTimeout) } - }) } } diff --git a/shared/relay/healthcheck/sender.go b/shared/relay/healthcheck/sender.go index 57b3015ec..771e94206 100644 --- a/shared/relay/healthcheck/sender.go +++ b/shared/relay/healthcheck/sender.go @@ -2,52 +2,76 @@ package healthcheck import ( "context" - "os" - "strconv" "time" log "github.com/sirupsen/logrus" ) const ( - defaultAttemptThreshold = 1 - defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" + defaultAttemptThreshold = 1 + + defaultHealthCheckInterval = 25 * time.Second + defaultHealthCheckTimeout = 20 * time.Second ) -var ( - healthCheckInterval = 25 * time.Second - healthCheckTimeout = 20 * time.Second -) +type SenderOptions struct { + HealthCheckInterval time.Duration + HealthCheckTimeout time.Duration + AttemptThreshold int +} // Sender is a healthcheck sender // It will send healthcheck signal to the receiver // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // It will also stop if the context is canceled type Sender struct { - log *log.Entry // HealthCheck is a channel to send health check signal to the peer HealthCheck chan struct{} // Timeout is a channel to the health check signal is not received in a certain time Timeout chan struct{} + log *log.Entry + 
healthCheckInterval time.Duration + timeout time.Duration + ack chan struct{} alive bool attemptThreshold int } -// NewSender creates a new healthcheck sender -func NewSender(log *log.Entry) *Sender { +func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender { + if opts.HealthCheckInterval <= 0 { + opts.HealthCheckInterval = defaultHealthCheckInterval + } + if opts.HealthCheckTimeout <= 0 { + opts.HealthCheckTimeout = defaultHealthCheckTimeout + } + if opts.AttemptThreshold <= 0 { + opts.AttemptThreshold = defaultAttemptThreshold + } hc := &Sender{ - log: log, - HealthCheck: make(chan struct{}, 1), - Timeout: make(chan struct{}, 1), - ack: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + log: log, + healthCheckInterval: opts.HealthCheckInterval, + timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout, + ack: make(chan struct{}, 1), + attemptThreshold: opts.AttemptThreshold, } return hc } +// NewSender creates a new healthcheck sender +func NewSender(log *log.Entry) *Sender { + opts := SenderOptions{ + HealthCheckInterval: defaultHealthCheckInterval, + HealthCheckTimeout: defaultHealthCheckTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewSenderWithOpts(log, opts) +} + // OnHCResponse sends an acknowledgment signal to the sender func (hc *Sender) OnHCResponse() { select { @@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() { } func (hc *Sender) StartHealthCheck(ctx context.Context) { - ticker := time.NewTicker(healthCheckInterval) + ticker := time.NewTicker(hc.healthCheckInterval) defer ticker.Stop() - timeoutTicker := time.NewTicker(hc.getTimeoutTime()) + timeoutTicker := time.NewTicker(hc.timeout) defer timeoutTicker.Stop() defer close(hc.HealthCheck) @@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) { } } } - -func (hc *Sender) getTimeoutTime() time.Duration { - return healthCheckInterval + 
healthCheckTimeout -} - -func getAttemptThresholdFromEnv() int { - if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { - threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) - if err != nil { - log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) - return defaultAttemptThreshold - } - return int(threshold) - } - return defaultAttemptThreshold -} diff --git a/shared/relay/healthcheck/sender_test.go b/shared/relay/healthcheck/sender_test.go index 23446366a..122fe0f16 100644 --- a/shared/relay/healthcheck/sender_test.go +++ b/shared/relay/healthcheck/sender_test.go @@ -2,26 +2,23 @@ package healthcheck import ( "context" - "fmt" - "os" "testing" "time" log "github.com/sirupsen/logrus" ) -func TestMain(m *testing.M) { - // override the health check interval to speed up the test - healthCheckInterval = 2 * time.Second - healthCheckTimeout = 100 * time.Millisecond - code := m.Run() - os.Exit(code) -} +var ( + testOpts = SenderOptions{ + HealthCheckInterval: 2 * time.Second, + HealthCheckTimeout: 100 * time.Millisecond, + } +) func TestNewHealthPeriod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) { func TestNewHealthFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), 
testOpts) go hc.StartHealthCheck(ctx) select { case <-hc.Timeout: - case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond): t.Fatalf("health check is not timed out") } } func TestNewHealthcheckStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) time.Sleep(100 * time.Millisecond) @@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) { func TestTimeoutReset(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - originalInterval := healthCheckInterval - originalTimeout := healthCheckTimeout - healthCheckInterval = 1 * time.Second - healthCheckTimeout = 500 * time.Millisecond - - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := SenderOptions{ + HealthCheckInterval: 1 * time.Second, + HealthCheckTimeout: 500 * time.Millisecond, + AttemptThreshold: tc.threshold, + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sender := NewSender(log.WithField("test_name", tc.name)) + sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts) 
senderExit := make(chan struct{}) go func() { sender.StartHealthCheck(ctx) @@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { } }() - testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval select { case <-sender.Timeout: @@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } - healthCheckInterval = originalInterval - healthCheckTimeout = originalTimeout }) } } - -//nolint:tenv -func TestGetAttemptThresholdFromEnv(t *testing.T) { - tests := []struct { - name string - envValue string - expected int - }{ - {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, - {"Custom attempt threshold when env is set to a valid integer", "3", 3}, - {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.envValue == "" { - os.Unsetenv(defaultAttemptThresholdEnv) - } else { - os.Setenv(defaultAttemptThresholdEnv, tt.envValue) - } - - result := getAttemptThresholdFromEnv() - if result != tt.expected { - t.Fatalf("Expected %d, got %d", tt.expected, result) - } - - os.Unsetenv(defaultAttemptThresholdEnv) - }) - } -} diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 82ab678f4..5ca0c0282 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -16,10 +16,10 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier 
is a wrapper interface of the status recorder @@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/sharedsock/example/main.go b/sharedsock/example/main.go index 9384d2b1c..da62b276e 100644 --- a/sharedsock/example/main.go +++ b/sharedsock/example/main.go @@ -5,14 +5,16 @@ import ( "os" "os/signal" - "github.com/netbirdio/netbird/sharedsock" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/sharedsock" ) func main() { port := 51820 - rawSock, err := sharedsock.Listen(port, sharedsock.NewIncomingSTUNFilter()) + rawSock, err := sharedsock.Listen(port, sharedsock.NewIncomingSTUNFilter(), iface.DefaultMTU) if err != nil { panic(err) } diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index 1c22e7869..bc2d4d1be 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -22,7 +22,7 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -36,6 +36,7 @@ type SharedSocket struct { conn4 *socket.Conn conn6 *socket.Conn port int + mtu uint16 routerMux sync.RWMutex router routing.Router packetDemux chan rcvdPacket @@ -56,12 +57,19 @@ var writeSerializerOptions = gopacket.SerializeOptions{ FixLengths: true, } +// Maximum overhead for IP + UDP headers on raw socket +// IPv4: max 60 bytes (20 base + 40 options) + UDP 8 bytes = 68 bytes +// IPv6: 40 bytes + UDP 8 bytes = 48 bytes +// We use the maximum (68) for both IPv4 and IPv6 +const maxIPUDPOverhead = 68 + // Listen creates an IPv4 and IPv6 raw sockets, starts a reader and 
routing table routines -func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { +func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error) { ctx, cancel := context.WithCancel(context.Background()) rawSock := &SharedSocket{ ctx: ctx, cancel: cancel, + mtu: mtu, port: port, packetDemux: make(chan rcvdPacket), } @@ -85,7 +93,7 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { } if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err) } var sockErr error @@ -94,7 +102,7 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { log.Errorf("Failed to create ipv6 raw socket: %v", err) } else { if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err) } } @@ -223,7 +231,7 @@ func (s *SharedSocket) Close() error { // read start a read loop for a specific receiver and sends the packet to the packetDemux channel func (s *SharedSocket) read(receiver receiver) { for { - buf := make([]byte, 1500) + buf := make([]byte, s.mtu+maxIPUDPOverhead) n, addr, err := receiver(s.ctx, buf, 0) select { case <-s.ctx.Done(): diff --git a/sharedsock/sock_linux_test.go b/sharedsock/sock_linux_test.go index f5c85119c..a22af461a 100644 --- a/sharedsock/sock_linux_test.go +++ b/sharedsock/sock_linux_test.go @@ -21,7 +21,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) { // create raw socket on a port testingPort := 51821 - rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter(), 1280) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second)) require.NoError(t, err, "unable to set 
deadline, error: %s", err) @@ -76,7 +76,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) { func TestShouldNotReadNonSTUNPackets(t *testing.T) { testingPort := 39439 - rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter(), 1280) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) defer rawSock.Close() @@ -110,7 +110,7 @@ func TestWriteTo(t *testing.T) { defer udpListener.Close() testingPort := 39440 - rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter(), 1280) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) defer rawSock.Close() @@ -144,7 +144,7 @@ func TestWriteTo(t *testing.T) { } func TestSharedSocket_Close(t *testing.T) { - rawSock, err := Listen(39440, NewIncomingSTUNFilter()) + rawSock, err := Listen(39440, NewIncomingSTUNFilter(), 1280) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) errGrp := errgroup.Group{} diff --git a/sharedsock/sock_nolinux.go b/sharedsock/sock_nolinux.go index a36ef67c6..a92f22edf 100644 --- a/sharedsock/sock_nolinux.go +++ b/sharedsock/sock_nolinux.go @@ -9,6 +9,6 @@ import ( ) // Listen is not supported on other platforms then Linux -func Listen(port int, filter BPFFilter) (net.PacketConn, error) { +func Listen(port int, filter BPFFilter, mtu uint16) (net.PacketConn, error) { return nil, fmt.Errorf("not supported OS %s. 
SharedSocket is only supported on Linux", runtime.GOOS) } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 2e89b491a..1d76fa4e4 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/http" + // nolint:gosec _ "net/http/pprof" "strings" diff --git a/signal/peer/peer.go b/signal/peer/peer.go index f21c95a41..c9dd60fc0 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -5,10 +5,16 @@ import ( "sync" "time" + "errors" + log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/shared/signal/proto" + "github.com/netbirdio/netbird/signal/metrics" +) + +var ( + ErrPeerAlreadyRegistered = errors.New("peer already registered") ) // Peer representation of a connected Peer @@ -23,15 +29,18 @@ type Peer struct { // registration time RegisteredAt time.Time + + Cancel context.CancelFunc } // NewPeer creates a new instance of a connected Peer -func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer { +func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer { return &Peer{ Id: id, Stream: stream, StreamID: time.Now().UnixNano(), RegisteredAt: time.Now(), + Cancel: cancel, } } @@ -69,20 +78,24 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool { } // Register registers peer in the registry -func (registry *Registry) Register(peer *Peer) { +func (registry *Registry) Register(peer *Peer) error { start := time.Now() - registry.regMutex.Lock() - defer registry.regMutex.Unlock() - // can be that peer already exists, but it is fine (e.g. reconnect) p, loaded := registry.Peers.LoadOrStore(peer.Id, peer) if loaded { pp := p.(*Peer) - log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. 
Will override stream.", - peer.Id, peer.StreamID, pp.StreamID) - registry.Peers.Store(peer.Id, peer) - return + if peer.StreamID > pp.StreamID { + log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.", + peer.Id, peer.StreamID, pp.StreamID) + if swapped := registry.Peers.CompareAndSwap(peer.Id, pp, peer); !swapped { + return registry.Register(peer) + } + pp.Cancel() + log.Debugf("peer re-registered [%s]", peer.Id) + return nil + } + return ErrPeerAlreadyRegistered } log.Debugf("peer registered [%s]", peer.Id) @@ -92,22 +105,13 @@ func (registry *Registry) Register(peer *Peer) { registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) registry.metrics.Registrations.Add(context.Background(), 1) + + return nil } // Deregister Peer from the Registry (usually once it disconnects) func (registry *Registry) Deregister(peer *Peer) { - registry.regMutex.Lock() - defer registry.regMutex.Unlock() - - p, loaded := registry.Peers.LoadAndDelete(peer.Id) - if loaded { - pp := p.(*Peer) - if peer.StreamID < pp.StreamID { - registry.Peers.Store(peer.Id, p) - log.Debugf("attempted to remove newer registered stream of a peer [%s] [newer streamID %d, previous StreamID %d]. 
Ignoring.", - peer.Id, pp.StreamID, peer.StreamID) - return - } + if deleted := registry.Peers.CompareAndDelete(peer.Id, peer); deleted { registry.metrics.ActivePeers.Add(context.Background(), -1) log.Debugf("peer deregistered [%s]", peer.Id) registry.metrics.Deregistrations.Add(context.Background(), 1) diff --git a/signal/peer/peer_test.go b/signal/peer/peer_test.go index fb85fedda..6b7976eb4 100644 --- a/signal/peer/peer_test.go +++ b/signal/peer/peer_test.go @@ -1,13 +1,18 @@ package peer import ( + "context" + "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" ) @@ -19,12 +24,16 @@ func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T) peerID := "peer" - olderPeer := NewPeer(peerID, nil) - r.Register(olderPeer) + _, cancel1 := context.WithCancel(context.Background()) + olderPeer := NewPeer(peerID, nil, cancel1) + err = r.Register(olderPeer) + require.NoError(t, err) time.Sleep(time.Nanosecond) - newerPeer := NewPeer(peerID, nil) - r.Register(newerPeer) + _, cancel2 := context.WithCancel(context.Background()) + newerPeer := NewPeer(peerID, nil, cancel2) + err = r.Register(newerPeer) + require.NoError(t, err) registered, _ := r.Get(olderPeer.Id) assert.NotNil(t, registered, "peer can't be nil") @@ -59,10 +68,14 @@ func TestRegistry_Register(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - peer1 := NewPeer("test_peer_1", nil) - peer2 := NewPeer("test_peer_2", nil) - r.Register(peer1) - r.Register(peer2) + _, cancel1 := context.WithCancel(context.Background()) + peer1 := NewPeer("test_peer_1", nil, cancel1) + _, cancel2 := context.WithCancel(context.Background()) + peer2 := NewPeer("test_peer_2", nil, cancel2) + err = r.Register(peer1) + require.NoError(t, err) + err = r.Register(peer2) + 
require.NoError(t, err) if _, ok := r.Get("test_peer_1"); !ok { t.Errorf("expected test_peer_1 not found in the registry") @@ -78,10 +91,14 @@ func TestRegistry_Deregister(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - peer1 := NewPeer("test_peer_1", nil) - peer2 := NewPeer("test_peer_2", nil) - r.Register(peer1) - r.Register(peer2) + _, cancel1 := context.WithCancel(context.Background()) + peer1 := NewPeer("test_peer_1", nil, cancel1) + _, cancel2 := context.WithCancel(context.Background()) + peer2 := NewPeer("test_peer_2", nil, cancel2) + err = r.Register(peer1) + require.NoError(t, err) + err = r.Register(peer2) + require.NoError(t, err) r.Deregister(peer1) @@ -94,3 +111,213 @@ func TestRegistry_Deregister(t *testing.T) { } } + +func TestRegistry_MultipleRegister_Concurrency(t *testing.T) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + + numGoroutines := 1000 + + ids := make(chan int64, numGoroutines) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + peerID := "peer-concurrent" + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + _, cancel := context.WithCancel(context.Background()) + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + ids <- peer.StreamID + }(i) + } + + wg.Wait() + close(ids) + maxId := int64(0) + for id := range ids { + maxId = max(maxId, id) + } + + peer, ok := registry.Get(peerID) + require.True(t, ok, "expected peer to be registered") + require.Equal(t, maxId, peer.StreamID, "expected the highest StreamID to be registered") +} + +func Benchmark_MultipleRegister_Concurrency(b *testing.B) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(b, err) + + numGoroutines := 1000 + + var wg sync.WaitGroup + peerID := "peer-concurrent" + _, cancel := context.WithCancel(context.Background()) + b.Run("multiple-register", func(b *testing.B) { + registry := NewRegistry(metrics) + 
b.ResetTimer() + for j := 0; j < b.N; j++ { + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + }(i) + } + wg.Wait() + } + }) +} + +func TestRegistry_MultipleDeregister_Concurrency(t *testing.T) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + + numGoroutines := 1000 + + ids := make(chan int64, numGoroutines) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + peerID := "peer-concurrent" + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + _, cancel := context.WithCancel(context.Background()) + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + ids <- peer.StreamID + registry.Deregister(peer) + }(i) + } + + wg.Wait() + close(ids) + maxId := int64(0) + for id := range ids { + maxId = max(maxId, id) + } + + _, ok := registry.Get(peerID) + require.False(t, ok, "expected peer to be deregistered") +} + +func Benchmark_MultipleDeregister_Concurrency(b *testing.B) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(b, err) + + numGoroutines := 1000 + + var wg sync.WaitGroup + peerID := "peer-concurrent" + _, cancel := context.WithCancel(context.Background()) + b.Run("register-deregister", func(b *testing.B) { + registry := NewRegistry(metrics) + b.ResetTimer() + for j := 0; j < b.N; j++ { + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + time.Sleep(time.Nanosecond) + registry.Deregister(peer) + }(i) + } + wg.Wait() + } + }) +} + +type mockConnectStreamServer struct { + grpc.ServerStream + ctx context.Context +} + +func (m *mockConnectStreamServer) Context() context.Context { + return m.ctx +} + +func (m *mockConnectStreamServer) SendHeader(md metadata.MD) error { + return nil 
+} + +func (m *mockConnectStreamServer) Send(msg *proto.EncryptedMessage) error { + return nil +} + +func (m *mockConnectStreamServer) Recv() (*proto.EncryptedMessage, error) { + <-m.ctx.Done() + return nil, m.ctx.Err() +} + +func TestReconnectHandling(t *testing.T) { + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + peerID := "test-peer-reconnect" + + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + stream1 := &mockConnectStreamServer{ctx: ctx1} + peer1 := NewPeer(peerID, stream1, cancel1) + + err = registry.Register(peer1) + require.NoError(t, err, "first registration should succeed") + + p, found := registry.Get(peerID) + require.True(t, found, "peer should be found in the registry") + require.Equal(t, peer1.StreamID, p.StreamID, "StreamID of registered peer should match") + + time.Sleep(time.Nanosecond) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + stream2 := &mockConnectStreamServer{ctx: ctx2} + peer2 := NewPeer(peerID, stream2, cancel2) + + err = registry.Register(peer2) + require.NoError(t, err, "reconnect registration should succeed") + + select { + case <-ctx1.Done(): + require.ErrorIs(t, ctx1.Err(), context.Canceled, "context of old stream should be canceled after successful reconnection") + case <-time.After(100 * time.Millisecond): + t.Fatal("context of old stream was not canceled after reconnection") + } + + p, found = registry.Get(peerID) + require.True(t, found) + require.Equal(t, peer2.StreamID, p.StreamID, "registered peer should have the new StreamID after reconnection") + + ctx3, cancel3 := context.WithCancel(context.Background()) + defer cancel3() + stream3 := &mockConnectStreamServer{ctx: ctx3} + stalePeer := NewPeer(peerID, stream3, cancel3) + stalePeer.StreamID = peer1.StreamID + + err = registry.Register(stalePeer) + require.ErrorIs(t, err, ErrPeerAlreadyRegistered, "reconnecting with an old StreamID should 
return an error") + + p, found = registry.Get(peerID) + require.True(t, found) + require.Equal(t, peer2.StreamID, p.StreamID, "active peer should still be the one with the latest StreamID") + + select { + case <-ctx2.Done(): + t.Fatal("context of the new stream should not be canceled after trying to register with an old StreamID") + default: + } +} diff --git a/signal/server/signal.go b/signal/server/signal.go index 8ae14822b..47f01edae 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -2,7 +2,9 @@ package server import ( "context" + "errors" "fmt" + "os" "time" log "github.com/sirupsen/logrus" @@ -15,9 +17,9 @@ import ( "github.com/netbirdio/signal-dispatcher/dispatcher" + "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" - "github.com/netbirdio/netbird/shared/signal/proto" ) const ( @@ -27,6 +29,8 @@ const ( labelTypeNotRegistered = "not_registered" labelTypeStream = "stream" labelTypeMessage = "message" + labelTypeTimeout = "timeout" + labelTypeDisconnected = "disconnected" labelError = "error" labelErrorMissingId = "missing_id" @@ -37,6 +41,12 @@ const ( labelRegistrationStatus = "status" labelRegistrationFound = "found" labelRegistrationNotFound = "not_found" + + sendTimeout = 10 * time.Second +) + +var ( + ErrPeerRegisteredAgain = errors.New("peer registered again") ) // Server an instance of a Signal server @@ -45,6 +55,10 @@ type Server struct { proto.UnimplementedSignalExchangeServer dispatcher *dispatcher.Dispatcher metrics *metrics.AppMetrics + + successHeader metadata.MD + + sendTimeout time.Duration } // NewServer creates a new Signal server @@ -59,10 +73,19 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating dispatcher: %v", err) } + sTimeout := sendTimeout + to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT") + if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 { + 
log.Trace("using custom send timeout ", parsed) + sTimeout = parsed + } + s := &Server{ - dispatcher: d, - registry: peer.NewRegistry(appMetrics), - metrics: appMetrics, + dispatcher: d, + registry: peer.NewRegistry(appMetrics), + metrics: appMetrics, + successHeader: metadata.Pairs(proto.HeaderRegistered, "1"), + sendTimeout: sTimeout, } return s, nil @@ -82,7 +105,8 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. // ConnectStream connects to the exchange stream func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error { - p, err := s.RegisterPeer(stream) + ctx, cancel := context.WithCancel(context.Background()) + p, err := s.RegisterPeer(stream, cancel) if err != nil { return err } @@ -90,8 +114,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) defer s.DeregisterPeer(p) // needed to confirm that the peer has been registered so that the client can proceed - header := metadata.Pairs(proto.HeaderRegistered, "1") - err = stream.SendHeader(header) + err = stream.SendHeader(s.successHeader) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader))) return err @@ -99,27 +122,27 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) - <-stream.Context().Done() - log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) - return nil + select { + case <-stream.Context().Done(): + log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) + return nil + case <-ctx.Done(): + return ErrPeerRegisteredAgain + } } -func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { +func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) { log.Debugf("registering new 
peer") - meta, hasMeta := metadata.FromIncomingContext(stream.Context()) - if !hasMeta { - s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta))) - return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") - } - - id, found := meta[proto.HeaderId] - if !found { + id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId) + if id == nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId) } - p := peer.NewPeer(id[0], stream) - s.registry.Register(p) + p := peer.NewPeer(id[0], stream, cancel) + if err := s.registry.Register(p); err != nil { + return nil, err + } err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration))) @@ -131,8 +154,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) ( func (s *Server) DeregisterPeer(p *peer.Peer) { log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) - s.registry.Deregister(p) s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) + s.registry.Deregister(p) } func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { @@ -145,7 +168,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM if !found { s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) s.metrics.MessageForwardFailures.Add(ctx, 1, 
metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) - log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) + log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) // todo respond to the sender? return } @@ -153,16 +176,34 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) start := time.Now() - // forward the message to the target peer - if err := dstPeer.Stream.Send(msg); err != nil { - log.Tracef("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) - // todo respond to the sender? - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - return - } + sendResultChan := make(chan error, 1) + go func() { + select { + case sendResultChan <- dstPeer.Stream.Send(msg): + return + case <-dstPeer.Stream.Context().Done(): + return + } + }() - // in milliseconds - s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) - s.metrics.MessagesForwarded.Add(ctx, 1) - s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) + select { + case err := <-sendResultChan: + if err != nil { + log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + return + } + s.metrics.MessageForwardLatency.Record(ctx, 
float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(ctx, 1) + s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) + + case <-dstPeer.Stream.Context().Done(): + log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected))) + + case <-time.After(s.sendTimeout): + dstPeer.Cancel() // cancel the peer context to trigger deregistration + log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout))) + } } diff --git a/util/net/conn.go b/util/net/conn.go deleted file mode 100644 index 26693f841..000000000 --- a/util/net/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !ios - -package net - -import ( - "net" - - log "github.com/sirupsen/logrus" -) - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} diff --git a/util/net/dial.go b/util/net/dial.go deleted file mode 100644 index 595311492..000000000 --- a/util/net/dial.go +++ /dev/null @@ -1,58 +0,0 @@ -//go:build !ios - -package net - -import ( - "fmt" - "net" - - log "github.com/sirupsen/logrus" -) - -func DialUDP(network string, laddr, raddr 
*net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_dial.go b/util/net/dialer_dial.go deleted file mode 100644 index 1659b6220..000000000 --- a/util/net/dialer_dial.go +++ /dev/null @@ -1,107 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. 
-func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHooks removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - log.Debugf("Dialing %s %s", network, address) - - if CustomRoutingDisabled() { - return d.Dialer.DialContext(ctx, network, address) - } - - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := callDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - 
ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} diff --git a/util/net/dialer_init_nonlinux.go b/util/net/dialer_init_nonlinux.go deleted file mode 100644 index 8c57ebbaa..000000000 --- a/util/net/dialer_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (d *Dialer) init() { - // implemented on Linux and Android only -} diff --git a/util/net/env_generic.go b/util/net/env_generic.go deleted file mode 100644 index 6d142a838..000000000 --- a/util/net/env_generic.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !linux || android - -package net - -func Init() { - // nothing to do on non-linux -} - -func AdvancedRouting() bool { - // non-linux currently doesn't support advanced routing - return false -} diff --git a/util/net/listen.go b/util/net/listen.go deleted file mode 100644 index 3ae8a9435..000000000 --- a/util/net/listen.go +++ /dev/null @@ -1,37 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" -) - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. 
-func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_init_nonlinux.go b/util/net/listener_init_nonlinux.go deleted file mode 100644 index 80f6f7f1a..000000000 --- a/util/net/listener_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (l *ListenerConfig) init() { - // implemented on Linux and Android only -} diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go deleted file mode 100644 index 4060ab49a..000000000 --- a/util/net/listener_listen.go +++ /dev/null @@ -1,205 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. 
-type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc - listenerAddressRemoveHooksMutex sync.RWMutex - listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. -func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) -} - -// RemoveListenerHooks removes all listener hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil - - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. 
-func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - if CustomRoutingDisabled() { - return l.ListenConfig.ListenPacket(ctx, network, address) - } - - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -// RemoveAddress removes an address from the seen cache and triggers removal hooks. 
-func (c *PacketConn) RemoveAddress(addr string) { - if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { - return - } - - ipStr, _, err := net.SplitHostPort(addr) - if err != nil { - log.Errorf("Error splitting IP address and port: %v", err) - return - } - - ipAddr, err := netip.ParseAddr(ipStr) - if err != nil { - log.Errorf("Error parsing IP address %s: %v", ipStr, err) - return - } - - prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) - - listenerAddressRemoveHooksMutex.RLock() - defer listenerAddressRemoveHooksMutex.RUnlock() - - for _, hook := range listenerAddressRemoveHooks { - if err := hook(c.ID, prefix); err != nil { - log.Errorf("Error executing listener address remove hook: %v", err) - } - } -} - - -// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality -func WrapPacketConn(conn net.PacketConn) *PacketConn { - return &PacketConn{ - PacketConn: conn, - ID: GenerateConnID(), - seenAddrs: &sync.Map{}, - } -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := 
range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -}