diff --git a/.github/workflows/docs-ack.yml b/.github/workflows/docs-ack.yml index f11142a36..9116be8c7 100644 --- a/.github/workflows/docs-ack.yml +++ b/.github/workflows/docs-ack.yml @@ -16,29 +16,19 @@ jobs: steps: - name: Read PR body id: body - shell: bash run: | - set -euo pipefail - BODY_B64=$(jq -r '.pull_request.body // "" | @base64' "$GITHUB_EVENT_PATH") - { - echo "body_b64=$BODY_B64" - } >> "$GITHUB_OUTPUT" + BODY=$(jq -r '.pull_request.body // ""' "$GITHUB_EVENT_PATH") + echo "body<> $GITHUB_OUTPUT + echo "$BODY" >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT - name: Validate checkbox selection id: validate - shell: bash - env: - BODY_B64: ${{ steps.body.outputs.body_b64 }} run: | - set -euo pipefail - if ! body="$(printf '%s' "$BODY_B64" | base64 -d)"; then - echo "::error::Failed to decode PR body from base64. Data may be corrupted or missing." - exit 1 - fi - - added_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*I added/updated documentation' | wc -l | tr -d '[:space:]' || true) - noneed_checked=$(printf '%s' "$body" | grep -Ei '^[[:space:]]*-\s*\[x\]\s*Documentation is \*\*not needed\*\*' | wc -l | tr -d '[:space:]' || true) + body='${{ steps.body.outputs.body }}' + added_checked=$(printf "%s" "$body" | grep -E '^- \[x\] I added/updated documentation' -i | wc -l | tr -d ' ') + noneed_checked=$(printf "%s" "$body" | grep -E '^- \[x\] Documentation is \*\*not needed\*\*' -i | wc -l | tr -d ' ') if [ "$added_checked" -eq 1 ] && [ "$noneed_checked" -eq 1 ]; then echo "::error::Choose exactly one: either 'docs added' OR 'not needed'." @@ -51,35 +41,30 @@ jobs: fi if [ "$added_checked" -eq 1 ]; then - echo "mode=added" >> "$GITHUB_OUTPUT" + echo "mode=added" >> $GITHUB_OUTPUT else - echo "mode=noneed" >> "$GITHUB_OUTPUT" + echo "mode=noneed" >> $GITHUB_OUTPUT fi - name: Extract docs PR URL (when 'docs added') if: steps.validate.outputs.mode == 'added' id: extract - shell: bash - env: - BODY_B64: ${{ steps.body.outputs.body_b64 }} run: | - set -euo pipefail - body="$(printf '%s' "$BODY_B64" | base64 -d)" + body='${{ steps.body.outputs.body }}' # Strictly require HTTPS and that it's a PR in netbirdio/docs - # e.g., https://github.com/netbirdio/docs/pull/1234 - url="$(printf '%s' "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true)" + # Examples accepted: + # https://github.com/netbirdio/docs/pull/1234 + url=$(printf "%s" "$body" | grep -Eo 'https://github\.com/netbirdio/docs/pull/[0-9]+' | head -n1 || true) - if [ -z "${url:-}" ]; then + if [ -z "$url" ]; then echo "::error::You checked 'docs added' but didn't include a valid HTTPS PR link to netbirdio/docs (e.g., https://github.com/netbirdio/docs/pull/1234)." exit 1 fi - pr_number="$(printf '%s' "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#')" - { - echo "url=$url" - echo "pr_number=$pr_number" - } >> "$GITHUB_OUTPUT" + pr_number=$(echo "$url" | sed -E 's#.*/pull/([0-9]+)$#\1#') + echo "url=$url" >> $GITHUB_OUTPUT + echo "pr_number=$pr_number" >> $GITHUB_OUTPUT - name: Verify docs PR exists (and is open or merged) if: steps.validate.outputs.mode == 'added' diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ba36c013b..0013833c4 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -217,7 +217,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "-race" + raceFlag: "" runs-on: ubuntu-22.04 steps: - name: Install Go @@ -382,32 +382,6 @@ jobs: store: [ 'sqlite', 'postgres' ] runs-on: ubuntu-22.04 steps: - - name: Create Docker network - run: docker network create promnet - - - name: Start Prometheus Pushgateway - run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway - - - name: Start Prometheus (for Pushgateway forwarding) - run: | - echo ' - global: - scrape_interval: 15s - scrape_configs: - - job_name: "pushgateway" - static_configs: - - targets: ["pushgateway:9091"] - remote_write: - - url: ${{ secrets.GRAFANA_URL }} - basic_auth: - username: ${{ secrets.GRAFANA_USER }} - password: ${{ secrets.GRAFANA_API_KEY }} - ' > prometheus.yml - - docker run -d --name prometheus --network promnet \ - -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ - -p 9090:9090 \ - prom/prometheus - name: Install Go uses: actions/setup-go@v5 with: @@ -454,10 +428,9 @@ jobs: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \ CI=true \ - GIT_BRANCH=${{ github.ref_name }} \ go test -tags devcert -run=^$ -bench=. \ - -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ - -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http) + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ + -timeout 20m ./management/... ./shared/management/... api_benchmark: name: "Management / Benchmark (API)" @@ -548,7 +521,7 @@ jobs: -run=^$ \ -bench=. \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ - -timeout 20m ./management/server/http/... + -timeout 20m ./management/... ./shared/management/... api_integration_test: name: "Management / Integration" @@ -598,4 +571,4 @@ jobs: CI=true \ go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./management/server/http/... + -timeout 20m ./management/... ./shared/management/... \ No newline at end of file diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2083c0721..d9ff0a84b 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -63,7 +63,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV + - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - name: test run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e9741f541..7be52259b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.23" + SIGN_PIPE_VER: "v0.0.22" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" diff --git a/README.md b/README.md index 2c5ee2ab6..ea7655869 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,3 @@ -


@@ -53,7 +52,7 @@ ### Open Source Network Security in a Single Platform -https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 +centralized-network-management 1 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..e19a09909 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -18,7 +18,7 @@ ENV \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="5" + NB_ENTRYPOINT_LOGIN_TIMEOUT="1" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/android/client.go b/client/android/client.go index 218817e62..678f5d9d5 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,7 +4,6 @@ package android import ( "context" - "os" "slices" "sync" @@ -19,7 +18,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/util/net" ) // ConnectionListener export internal Listener for mobile @@ -84,8 +83,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { - exportEnvList(envList) +func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -120,8 +118,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { - exportEnvList(envList) +func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -252,14 +249,3 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } - -func exportEnvList(list *EnvList) { - if list == nil { - return - } - for k, v := range list.AllItems() { - if err := os.Setenv(k, v); err != nil { - log.Errorf("could not set env variable %s: %v", k, err) - } - } -} diff --git a/client/android/env_list.go b/client/android/env_list.go deleted file mode 100644 index 04122300a..000000000 --- a/client/android/env_list.go +++ /dev/null @@ -1,32 +0,0 @@ -package android - -import "github.com/netbirdio/netbird/client/internal/peer" - -var ( - // EnvKeyNBForceRelay Exported for Android java client - EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay -) - -// EnvList wraps a Go map for export to Java -type EnvList struct { - data map[string]string -} - -// NewEnvList creates a new EnvList -func NewEnvList() *EnvList { - return &EnvList{data: make(map[string]string)} -} - -// Put adds a key-value pair -func (el *EnvList) Put(key, value string) { - el.data[key] = value -} - -// Get retrieves a value by key -func (el *EnvList) Get(key string) string { - return el.data[key] -} - -func (el *EnvList) AllItems() map[string]string { - return el.data -} diff --git a/client/android/login.go b/client/android/login.go index 0df78dbc3..d8ac645e2 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -33,7 +33,6 @@ type ErrListener interface { // the backend want to show an url for the user type URLOpener interface { Open(string) - OnLoginSuccess() } // Auth can register or login new client @@ -182,11 +181,6 @@ func (a *Auth) login(urlOpener URLOpener) error { err = a.withBackOff(a.ctx, func() error { err := internal.Login(a.ctx, a.config, "", jwtToken) - - if err == nil { - go urlOpener.OnLoginSuccess() - } - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return nil } diff --git a/client/cmd/down.go b/client/cmd/down.go index 17c152d22..3ce51c678 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -27,7 +27,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/cmd/login.go b/client/cmd/login.go index 3ac211805..92de6abdb 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, } // update host's static platform and system information - system.UpdateStaticInfoAsync() + system.UpdateStaticInfo() configFilePath, err := activeProf.FilePath() if err != nil { diff --git a/client/cmd/root.go b/client/cmd/root.go index 11e5228f1..8aa0d7c89 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -39,7 +39,6 @@ const ( extraIFaceBlackListFlag = "extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" enableLazyConnectionFlag = "enable-lazy-connection" - mtuFlag = "mtu" ) var ( @@ -73,7 +72,6 @@ var ( anonymizeFlag bool dnsRouteInterval time.Duration lazyConnEnabled bool - mtu uint16 profilesDisabled bool updateSettingsDisabled bool @@ -231,7 +229,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string { // DialClientGRPCServer returns client connection to the daemon server. func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*10) + ctx, cancel := context.WithTimeout(ctx, time.Second*3) defer cancel() return grpc.DialContext( diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go index ce95786dd..844eea853 100644 --- a/client/cmd/root_test.go +++ b/client/cmd/root_test.go @@ -54,7 +54,6 @@ func TestSetFlagsFromEnvVars(t *testing.T) { cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.") cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") - cmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface") t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec") t.Setenv("NB_INTERFACE_NAME", "test-name") diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0545ce6b7..50fb35d5e 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error { log.Info("starting NetBird service") //nolint // Collect static system and platform information - system.UpdateStaticInfoAsync() + system.UpdateStaticInfo() // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. p.serv = grpc.NewServer() diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 729b191c3..42cca1a9b 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -9,26 +9,29 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "google.golang.org/grpc" - "github.com/netbirdio/management-integrations/integrations" - clientProto "github.com/netbirdio/netbird/client/proto" - client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" - mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + + "github.com/netbirdio/netbird/util" + + "google.golang.org/grpc" + + "github.com/netbirdio/management-integrations/integrations" + + clientProto "github.com/netbirdio/netbird/client/proto" + client "github.com/netbirdio/netbird/client/server" + mgmt "github.com/netbirdio/netbird/management/server" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto" sig "github.com/netbirdio/netbird/signal/server" - "github.com/netbirdio/netbird/util" ) func startTestingServices(t *testing.T) string { @@ -88,20 +91,15 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { return nil, nil } - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - permissionsManagerMock := permissions.NewMockManager(ctrl) - peersmanager := peers.NewManager(store, permissionsManagerMock) - settingsManagerMock := settings.NewMockManager(ctrl) - - iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) + iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + permissionsManagerMock := permissions.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() settingsMockManager.EXPECT(). diff --git a/client/cmd/up.go b/client/cmd/up.go index 1b751aa55..7cc342fe0 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -63,7 +63,6 @@ func init() { upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "WireGuard interface name") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "WireGuard interface listening port") - upCmd.PersistentFlags().Uint16Var(&mtu, mtuFlag, iface.DefaultMTU, "Set MTU (Maximum Transmission Unit) for the WireGuard interface") upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+ `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, @@ -231,9 +230,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager client := proto.NewDaemonServiceClient(conn) - status, err := client.Status(ctx, &proto.StatusRequest{ - WaitForReady: func() *bool { b := true; return &b }(), - }) + status, err := client.Status(ctx, &proto.StatusRequest{}) if err != nil { return fmt.Errorf("unable to get daemon status: %v", err) } @@ -361,11 +358,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro req.WireguardPort = &p } - if cmd.Flag(mtuFlag).Changed { - m := int64(mtu) - req.Mtu = &m - } - if cmd.Flag(networkMonitorFlag).Changed { req.NetworkMonitor = &networkMonitor } @@ -445,13 +437,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil ic.WireguardPort = &p } - if cmd.Flag(mtuFlag).Changed { - if err := iface.ValidateMTU(mtu); err != nil { - return nil, err - } - ic.MTU = &mtu - } - if cmd.Flag(networkMonitorFlag).Changed { ic.NetworkMonitor = &networkMonitor } @@ -549,14 +534,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte loginRequest.WireguardPort = &wp } - if cmd.Flag(mtuFlag).Changed { - if err := iface.ValidateMTU(mtu); err != nil { - return nil, err - } - m := int64(mtu) - loginRequest.Mtu = &m - } - if cmd.Flag(networkMonitorFlag).Changed { loginRequest.NetworkMonitor = &networkMonitor } diff --git a/client/embed/embed.go b/client/embed/embed.go index c62efc960..79f5f0e43 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -137,7 +137,7 @@ func (c *Client) Start(startCtx context.Context) error { // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available - run := make(chan struct{}) + run := make(chan struct{}, 1) clientErr := make(chan error, 1) go func() { if err := client.Run(run); err != nil { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..7b90000a8 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -12,7 +12,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 081991235..1e44c7a4d 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) // constants needed to manage and create iptable rules diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 3490c5dad..e9eeff863 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,7 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) func isIptablesSupported() bool { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 9ff5b8c92..52979d257 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e918d0524..f8fed4d80 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -22,7 +22,7 @@ import ( nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go index 32b07c330..89bddf12c 100644 --- a/client/iface/bind/control.go +++ b/client/iface/bind/control.go @@ -3,7 +3,7 @@ package bind import ( wireguard "golang.zx2c4.com/wireguard/conn" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go index caa92f05d..1926ff88f 100644 --- a/client/iface/bind/endpoint.go +++ b/client/iface/bind/endpoint.go @@ -1,17 +1,5 @@ package bind -import ( - "net" - - wgConn "golang.zx2c4.com/wireguard/conn" -) +import wgConn "golang.zx2c4.com/wireguard/conn" type Endpoint = wgConn.StdNetEndpoint - -func EndpointToUDPAddr(e Endpoint) *net.UDPAddr { - return &net.UDPAddr{ - IP: e.Addr().AsSlice(), - Port: int(e.Port()), - Zone: e.Addr().Zone(), - } -} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index ef630b9d0..41f4aec6d 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,7 +1,6 @@ package bind import ( - "context" "encoding/binary" "fmt" "net" @@ -9,16 +8,15 @@ import ( "runtime" "sync" - "github.com/pion/stun/v3" + "github.com/pion/stun/v2" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) type RecvMessage struct { @@ -43,10 +41,10 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - recvChan chan RecvMessage + RecvChan chan RecvMessage transportNet transport.Net - filterFn udpmux.FilterFn + filterFn FilterFn endpoints map[netip.Addr]net.Conn endpointsMu sync.Mutex // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a @@ -56,23 +54,21 @@ type ICEBind struct { closed bool muUDPMux sync.Mutex - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *UniversalUDPMuxDefault address wgaddr.Address - mtu uint16 activityRecorder *ActivityRecorder } -func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - recvChan: make(chan RecvMessage, 1), + RecvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, endpoints: make(map[netip.Addr]net.Conn), closedChan: make(chan struct{}), closed: true, - mtu: mtu, address: address, activityRecorder: NewActivityRecorder(), } @@ -84,10 +80,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg return ib } -func (s *ICEBind) MTU() uint16 { - return s.mtu -} - func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { s.closed = false s.closedChanMu.Lock() @@ -117,7 +109,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder { } // GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { +func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() if s.udpMux == nil { @@ -156,25 +148,16 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { - select { - case <-ctx.Done(): - return - case b.recvChan <- msg: - } -} - func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = udpmux.NewUniversalUDPMuxDefault( - udpmux.UniversalUDPMuxParams{ + s.udpMux = NewUniversalUDPMuxDefault( + UniversalUDPMuxParams{ UDPConn: nbnet.WrapPacketConn(conn), Net: s.transportNet, FilterFn: s.filterFn, WGAddress: s.address, - MTU: s.mtu, }, ) return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { @@ -280,7 +263,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo select { case <-c.closedChan: return 0, net.ErrClosed - case msg, ok := <-c.recvChan: + case msg, ok := <-c.RecvChan: if !ok { return 0, net.ErrClosed } diff --git a/client/iface/udpmux/mux.go b/client/iface/bind/udp_mux.go similarity index 65% rename from client/iface/udpmux/mux.go rename to client/iface/bind/udp_mux.go index 319724926..29e5d7937 100644 --- a/client/iface/udpmux/mux.go +++ b/client/iface/bind/udp_mux.go @@ -1,4 +1,4 @@ -package udpmux +package bind import ( "fmt" @@ -8,9 +8,9 @@ import ( "strings" "sync" - "github.com/pion/ice/v4" + "github.com/pion/ice/v3" "github.com/pion/logging" - "github.com/pion/stun/v3" + "github.com/pion/stun/v2" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" @@ -22,9 +22,9 @@ import ( const receiveMTU = 8192 -// SingleSocketUDPMux is an implementation of the interface -type SingleSocketUDPMux struct { - params Params +// UDPMuxDefault is an implementation of the interface +type UDPMuxDefault struct { + params UDPMuxParams closedChan chan struct{} closeOnce sync.Once @@ -32,9 +32,6 @@ type SingleSocketUDPMux struct { // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn - // candidateConnMap maps local candidate IDs to their corresponding connection. - candidateConnMap map[string]*udpMuxedConn - addressMapMu sync.RWMutex addressMap map[string][]*udpMuxedConn @@ -49,8 +46,8 @@ type SingleSocketUDPMux struct { const maxAddrSize = 512 -// Params are parameters for UDPMux. -type Params struct { +// UDPMuxParams are parameters for UDPMux. +type UDPMuxParams struct { Logger logging.LeveledLogger UDPConn net.PacketConn @@ -150,19 +147,18 @@ func isZeros(ip net.IP) bool { return true } -// NewSingleSocketUDPMux creates an implementation of UDPMux -func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { +// NewUDPMuxDefault creates an implementation of UDPMux +func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { if params.Logger == nil { params.Logger = getLogger() } - mux := &SingleSocketUDPMux{ - addressMap: map[string][]*udpMuxedConn{}, - params: params, - connsIPv4: make(map[string]*udpMuxedConn), - connsIPv6: make(map[string]*udpMuxedConn), - candidateConnMap: make(map[string]*udpMuxedConn), - closedChan: make(chan struct{}, 1), + mux := &UDPMuxDefault{ + addressMap: map[string][]*udpMuxedConn{}, + params: params, + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address @@ -175,15 +171,15 @@ func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { return mux } -func (m *SingleSocketUDPMux) updateLocalAddresses() { +func (m *UDPMuxDefault) updateLocalAddresses() { var localAddrsForUnspecified []net.Addr if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) } else if ok && addr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection - // with SingleSocketUDPMux, so print a warn log and create a local address list for mux. - m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + // with UDPMuxDefault, so print a warn log and create a local address list for mux. + m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { @@ -220,13 +216,13 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() { m.mu.Unlock() } -// LocalAddr returns the listening address of this SingleSocketUDPMux -func (m *SingleSocketUDPMux) LocalAddr() net.Addr { +// LocalAddr returns the listening address of this UDPMuxDefault +func (m *UDPMuxDefault) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetListenAddresses returns the list of addresses that this mux is listening on -func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { +func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { m.updateLocalAddresses() m.mu.Lock() @@ -240,7 +236,7 @@ func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { // GetConn returns a PacketConn given the connection's ufrag and network address // creates the connection if an existing one can't be found -func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) { +func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { // don't check addr for mux using unspecified address m.mu.Lock() lenLocalAddrs := len(m.localAddrsForUnspecified) @@ -264,14 +260,12 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st return conn, nil } - c := m.createMuxedConn(ufrag, candidateID) + c := m.createMuxedConn(ufrag) go func() { <-c.CloseChannel() m.RemoveConnByUfrag(ufrag) }() - m.candidateConnMap[candidateID] = c - if isIPv6 { m.connsIPv6[ufrag] = c } else { @@ -282,7 +276,7 @@ func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID st } // RemoveConnByUfrag stops and removes the muxed packet connection -func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { +func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock section small to avoid deadlock with conn lock @@ -290,12 +284,10 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) - delete(m.candidateConnMap, c.GetCandidateID()) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) - delete(m.candidateConnMap, c.GetCandidateID()) } m.mu.Unlock() @@ -322,7 +314,7 @@ func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { } // IsClosed returns true if the mux had been closed -func (m *SingleSocketUDPMux) IsClosed() bool { +func (m *UDPMuxDefault) IsClosed() bool { select { case <-m.closedChan: return true @@ -332,7 +324,7 @@ func (m *SingleSocketUDPMux) IsClosed() bool { } // Close the mux, no further connections could be created -func (m *SingleSocketUDPMux) Close() error { +func (m *UDPMuxDefault) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() @@ -355,11 +347,11 @@ func (m *SingleSocketUDPMux) Close() error { return err } -func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { +func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } -func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) { +func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } @@ -376,109 +368,81 @@ func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr str log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } -func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn { +func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ - Mux: m, - Key: key, - AddrPool: m.pool, - LocalAddr: m.LocalAddr(), - Logger: m.params.Logger, - CandidateID: candidateID, + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: m.LocalAddr(), + Logger: m.params.Logger, }) return c } // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library -func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { +func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { + remoteAddr, ok := addr.(*net.UDPAddr) if !ok { return fmt.Errorf("underlying PacketConn did not return a UDPAddr") } - // Try to route to specific candidate connection first - if conn := m.findCandidateConnection(msg); conn != nil { - return conn.writePacket(msg.Raw, remoteAddr) - } - - // Fallback: route to all possible connections - return m.forwardToAllConnections(msg, addr, remoteAddr) -} - -// findCandidateConnection attempts to find the specific connection for a STUN message -func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn { - candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg) - if err != nil { - return nil - } else if !ok { - return nil - } - - m.mu.Lock() - defer m.mu.Unlock() - conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()] - if !exists { - return nil - } - return conn -} - -// forwardToAllConnections forwards STUN message to all relevant connections -func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error { - var destinationConnList []*udpMuxedConn - - // Add connections from address map + // If we have already seen this address dispatch to the appropriate destination + // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one + // muxed connection - one for the SRFLX candidate and the other one for the HOST one. + // We will then forward STUN packets to each of these connections. m.addressMapMu.RLock() + var destinationConnList []*udpMuxedConn if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = append(destinationConnList, storedConns...) } m.addressMapMu.RUnlock() - if conn, ok := m.findConnectionByUsername(msg, addr); ok { - // If we have already seen this address dispatch to the appropriate destination - // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one - // muxed connection - one for the SRFLX candidate and the other one for the HOST one. - // We will then forward STUN packets to each of these connections. - if !m.connectionExists(conn, destinationConnList) { - destinationConnList = append(destinationConnList, conn) - } + var isIPv6 bool + if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { + isIPv6 = true } - // Forward to all found connections + // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. + // However, we can take a username attribute from the STUN message which contains ufrag. + // We can use ufrag to identify the destination conn to route packet to. + attr, stunAttrErr := msg.Get(stun.AttrUsername) + if stunAttrErr == nil { + ufrag := strings.Split(string(attr), ":")[0] + + m.mu.Lock() + destinationConn := m.connsIPv4[ufrag] + if isIPv6 { + destinationConn = m.connsIPv6[ufrag] + } + + if destinationConn != nil { + exists := false + for _, conn := range destinationConnList { + if conn.params.Key == destinationConn.params.Key { + exists = true + break + } + } + if !exists { + destinationConnList = append(destinationConnList, destinationConn) + } + } + m.mu.Unlock() + } + + // Forward STUN packets to each destination connections even thought the STUN packet might not belong there. + // It will be discarded by the further ICE candidate logic if so. for _, conn := range destinationConnList { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { log.Errorf("could not write packet: %v", err) } } + return nil } -// findConnectionByUsername finds connection using username attribute from STUN message -func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) { - attr, err := msg.Get(stun.AttrUsername) - if err != nil { - return nil, false - } - - ufrag := strings.Split(string(attr), ":")[0] - isIPv6 := isIPv6Address(addr) - - m.mu.Lock() - defer m.mu.Unlock() - - return m.getConn(ufrag, isIPv6) -} - -// connectionExists checks if a connection already exists in the list -func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool { - for _, conn := range conns { - if conn.params.Key == target.params.Key { - return true - } - } - return false -} - -func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { +func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { @@ -487,13 +451,6 @@ func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedCo return } -func isIPv6Address(addr net.Addr) bool { - if udpAddr, ok := addr.(*net.UDPAddr); ok { - return udpAddr.IP.To4() == nil - } - return false -} - type bufferHolder struct { buf []byte } diff --git a/client/iface/udpmux/mux_generic.go b/client/iface/bind/udp_mux_generic.go similarity index 76% rename from client/iface/udpmux/mux_generic.go rename to client/iface/bind/udp_mux_generic.go index 29fc2d834..63f786d2b 100644 --- a/client/iface/udpmux/mux_generic.go +++ b/client/iface/bind/udp_mux_generic.go @@ -1,12 +1,12 @@ //go:build !ios -package udpmux +package bind import ( - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) -func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { +func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { conn.RemoveAddress(addr) diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go new file mode 100644 index 000000000..15e26d02f --- /dev/null +++ b/client/iface/bind/udp_mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package bind + +func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} \ No newline at end of file diff --git a/client/iface/udpmux/universal.go b/client/iface/bind/udp_mux_universal.go similarity index 95% rename from client/iface/udpmux/universal.go rename to client/iface/bind/udp_mux_universal.go index 43bfedaaa..b755a7827 100644 --- a/client/iface/udpmux/universal.go +++ b/client/iface/bind/udp_mux_universal.go @@ -1,4 +1,4 @@ -package udpmux +package bind /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. @@ -15,10 +15,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/pion/logging" - "github.com/pion/stun/v3" + "github.com/pion/stun/v2" "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +28,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. type UniversalUDPMuxDefault struct { - *SingleSocketUDPMux + *UDPMuxDefault params UniversalUDPMuxParams // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents @@ -45,7 +44,6 @@ type UniversalUDPMuxParams struct { Net transport.Net FilterFn FilterFn WGAddress wgaddr.Address - MTU uint16 } // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux @@ -72,12 +70,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - udpMuxParams := Params{ + udpMuxParams := UDPMuxParams{ Logger: params.Logger, UDPConn: m.params.UDPConn, Net: m.params.Net, } - m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams) + m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) return m } @@ -86,7 +84,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef // just ignore other packets printing an warning message. // It is a blocking method, consider running in a go routine. func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { - buf := make([]byte, m.params.MTU+bufsize.WGBufferOverhead) + buf := make([]byte, 1500) for { select { case <-ctx.Done(): @@ -211,8 +209,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. -func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) { - return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID) +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { + return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) } // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. @@ -233,7 +231,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A } return nil } - return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr) + return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) } // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. diff --git a/client/iface/udpmux/conn.go b/client/iface/bind/udp_muxed_conn.go similarity index 95% rename from client/iface/udpmux/conn.go rename to client/iface/bind/udp_muxed_conn.go index 3aa40caeb..7cacf1c31 100644 --- a/client/iface/udpmux/conn.go +++ b/client/iface/bind/udp_muxed_conn.go @@ -1,4 +1,4 @@ -package udpmux +package bind /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements @@ -16,12 +16,11 @@ import ( ) type udpMuxedConnParams struct { - Mux *SingleSocketUDPMux - AddrPool *sync.Pool - Key string - LocalAddr net.Addr - Logger logging.LeveledLogger - CandidateID string + Mux *UDPMuxDefault + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger } // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag @@ -120,10 +119,6 @@ func (c *udpMuxedConn) Close() error { return err } -func (c *udpMuxedConn) GetCandidateID() string { - return c.params.CandidateID -} - func (c *udpMuxedConn) isClosed() bool { select { case <-c.closedChan: diff --git a/client/iface/bufsize/bufsize.go b/client/iface/bufsize/bufsize.go deleted file mode 100644 index 0d2afb77d..000000000 --- a/client/iface/bufsize/bufsize.go +++ /dev/null @@ -1,9 +0,0 @@ -package bufsize - -const ( - // WGBufferOverhead represents the additional buffer space needed beyond MTU - // for WireGuard packet encapsulation (WG header + UDP + IP + safety margin) - // Original hardcoded buffers were 1500, default MTU is 1280, so overhead = 220 - // TODO: Calculate this properly based on actual protocol overhead instead of using hardcoded difference - WGBufferOverhead = 220 -) diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index f744e0127..171458e38 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -17,8 +17,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/monotime" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -394,13 +394,6 @@ func toLastHandshake(stringVar string) (time.Time, error) { if err != nil { return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) } - - // If sec is 0 (Unix epoch), return zero time instead - // This indicates no handshake has occurred - if sec == 0 { - return time.Time{}, nil - } - return time.Unix(sec, 0), nil } @@ -409,7 +402,7 @@ func toBytes(s string) (int64, error) { } func getFwmark() int { - if nbnet.AdvancedRouting() && runtime.GOOS == "linux" { + if nbnet.AdvancedRouting() { return nbnet.ControlPlaneMark } return 0 diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea98..81f2e0f47 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -7,17 +7,16 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create() (device.WGConfigurer, error) - Up() (*udpmux.UniversalUDPMuxDefault, error) + Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address - MTU() uint16 DeviceName() string Close() error FilteredDevice() *device.FilteredDevice diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cc..4fe6e466b 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -13,7 +13,6 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -22,7 +21,7 @@ type WGTunDevice struct { address wgaddr.Address port int key string - mtu uint16 + mtu int iceBind *bind.ICEBind tunAdapter TunAdapter disableDNS bool @@ -30,11 +29,11 @@ type WGTunDevice struct { name string device *device.Device filteredDevice *FilteredDevice - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { +func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice { return &WGTunDevice{ address: address, port: port, @@ -59,7 +58,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string searchDomainsToString = "" } - fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), int(t.mtu), dns, searchDomainsToString, routesString) + fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) if err != nil { log.Errorf("failed to create Android interface: %s", err) return nil, err @@ -89,7 +88,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string } return t.configurer, nil } -func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -138,10 +137,6 @@ func (t *WGTunDevice) WgAddress() wgaddr.Address { return t.address } -func (t *WGTunDevice) MTU() uint16 { - return t.mtu -} - func (t *WGTunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe088..81de0e360 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -13,7 +13,6 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -22,16 +21,16 @@ type TunDevice struct { address wgaddr.Address port int key string - mtu uint16 + mtu int iceBind *bind.ICEBind device *device.Device filteredDevice *FilteredDevice - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, @@ -43,7 +42,7 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu } func (t *TunDevice) Create() (WGConfigurer, error) { - tunDevice, err := tun.CreateTUN(t.name, int(t.mtu)) + tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } @@ -72,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -112,10 +111,6 @@ func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } -func (t *TunDevice) MTU() uint16 { - return t.mtu -} - func (t *TunDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf..4613762c3 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -14,7 +14,6 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -23,23 +22,21 @@ type TunDevice struct { address wgaddr.Address port int key string - mtu uint16 iceBind *bind.ICEBind tunFd int device *device.Device filteredDevice *FilteredDevice - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunFd int) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, - mtu: mtu, iceBind: iceBind, tunFd: tunFd, } @@ -84,7 +81,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -128,10 +125,6 @@ func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } -func (t *TunDevice) MTU() uint16 { - return t.mtu -} - func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { // todo implement return nil diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a53..7136be0bc 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -12,11 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/sharedsock" + nbnet "github.com/netbirdio/netbird/util/net" ) type TunKernelDevice struct { @@ -24,19 +24,19 @@ type TunKernelDevice struct { address wgaddr.Address wgPort int key string - mtu uint16 + mtu int ctx context.Context ctxCancel context.CancelFunc transportNet transport.Net link *wgLink udpMuxConn net.PacketConn - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault - filterFn udpmux.FilterFn + filterFn bind.FilterFn } -func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { +func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { ctx, cancel := context.WithCancel(context.Background()) return &TunKernelDevice{ ctx: ctx, @@ -66,7 +66,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { // TODO: do a MTU discovery log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name) - if err := link.setMTU(int(t.mtu)); err != nil { + if err := link.setMTU(t.mtu); err != nil { return nil, fmt.Errorf("set mtu: %w", err) } @@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { return configurer, nil } -func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -96,19 +96,23 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return nil, err } - rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter(), t.mtu) + rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter()) if err != nil { return nil, err } - bindParams := udpmux.UniversalUDPMuxParams{ - UDPConn: nbnet.WrapPacketConn(rawSock), + var udpConn net.PacketConn = rawSock + if !nbnet.AdvancedRouting() { + udpConn = nbnet.WrapPacketConn(rawSock) + } + + bindParams := bind.UniversalUDPMuxParams{ + UDPConn: udpConn, Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, - MTU: t.mtu, } - mux := udpmux.NewUniversalUDPMuxDefault(bindParams) + mux := bind.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) t.udpMuxConn = rawSock t.udpMux = mux @@ -154,10 +158,6 @@ func (t *TunKernelDevice) WgAddress() wgaddr.Address { return t.address } -func (t *TunKernelDevice) MTU() uint16 { - return t.mtu -} - func (t *TunKernelDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index a6ef47027..fc3cb0215 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,3 +1,6 @@ +//go:build !android +// +build !android + package device import ( @@ -10,9 +13,8 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) type TunNetstackDevice struct { @@ -20,20 +22,20 @@ type TunNetstackDevice struct { address wgaddr.Address port int key string - mtu uint16 + mtu int listenAddress string iceBind *bind.ICEBind device *device.Device filteredDevice *FilteredDevice nsTun *nbnetstack.NetStackTun - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault configurer WGConfigurer net *netstack.Net } -func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { +func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -45,7 +47,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri } } -func (t *TunNetstackDevice) create() (WGConfigurer, error) { +func (t *TunNetstackDevice) Create() (WGConfigurer, error) { log.Info("create nbnetstack tun interface") // TODO: get from service listener runtime IP @@ -55,7 +57,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { } log.Debugf("netstack using address: %s", t.address.IP) - t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, int(t.mtu)) + t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) log.Debugf("netstack using dns address: %s", dnsAddr) tunIface, net, err := t.nsTun.Create() if err != nil { @@ -81,7 +83,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -123,10 +125,6 @@ func (t *TunNetstackDevice) WgAddress() wgaddr.Address { return t.address } -func (t *TunNetstackDevice) MTU() uint16 { - return t.mtu -} - func (t *TunNetstackDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_netstack_android.go b/client/iface/device/device_netstack_android.go deleted file mode 100644 index 45ae8ba7d..000000000 --- a/client/iface/device/device_netstack_android.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build android - -package device - -func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { - return t.create() -} diff --git a/client/iface/device/device_netstack_generic.go b/client/iface/device/device_netstack_generic.go deleted file mode 100644 index 4b3974f26..000000000 --- a/client/iface/device/device_netstack_generic.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !android - -package device - -func (t *TunNetstackDevice) Create() (WGConfigurer, error) { - return t.create() -} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a32..e781f6004 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -12,7 +12,6 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -21,16 +20,16 @@ type USPDevice struct { address wgaddr.Address port int key string - mtu uint16 + mtu int iceBind *bind.ICEBind device *device.Device filteredDevice *FilteredDevice - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault configurer WGConfigurer } -func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *USPDevice { +func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { log.Infof("using userspace bind mode") return &USPDevice{ @@ -45,9 +44,9 @@ func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu func (t *USPDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") - tunIface, err := tun.CreateTUN(t.name, int(t.mtu)) + tunIface, err := tun.CreateTUN(t.name, t.mtu) if err != nil { - log.Debugf("failed to create tun interface (%s, %d): %s", t.name, int(t.mtu), err) + log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) return nil, fmt.Errorf("error creating tun device: %s", err) } t.filteredDevice = newDeviceFilter(tunIface) @@ -75,7 +74,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -119,10 +118,6 @@ func (t *USPDevice) WgAddress() wgaddr.Address { return t.address } -func (t *USPDevice) MTU() uint16 { - return t.mtu -} - func (t *USPDevice) DeviceName() string { return t.name } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a..0316c4b8d 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -13,7 +13,6 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -24,17 +23,17 @@ type TunDevice struct { address wgaddr.Address port int key string - mtu uint16 + mtu int iceBind *bind.ICEBind device *device.Device nativeTunDevice *tun.NativeTun filteredDevice *FilteredDevice - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault configurer WGConfigurer } -func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, @@ -60,7 +59,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return nil, err } log.Info("create tun interface") - tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, int(t.mtu)) + tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu) if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } @@ -105,7 +104,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -145,10 +144,6 @@ func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } -func (t *TunDevice) MTU() uint16 { - return t.mtu -} - func (t *TunDevice) DeviceName() string { return t.name } diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97..a1e246fc5 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -5,17 +5,16 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) - Up() (*udpmux.UniversalUDPMuxDefault, error) + Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address - MTU() uint16 DeviceName() string Close() error FilteredDevice() *device.FilteredDevice diff --git a/client/iface/iface.go b/client/iface/iface.go index 609572561..0e41f8e64 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -16,9 +16,9 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -26,8 +26,6 @@ import ( const ( DefaultMTU = 1280 - MinMTU = 576 - MaxMTU = 8192 DefaultWgPort = 51820 WgInterfaceDefault = configurer.WgInterfaceDefault ) @@ -37,17 +35,6 @@ var ( ErrIfaceNotFound = fmt.Errorf("wireguard interface not found") ) -// ValidateMTU validates that MTU is within acceptable range -func ValidateMTU(mtu uint16) error { - if mtu < MinMTU { - return fmt.Errorf("MTU %d below minimum (%d bytes)", mtu, MinMTU) - } - if mtu > MaxMTU { - return fmt.Errorf("MTU %d exceeds maximum supported size (%d bytes)", mtu, MaxMTU) - } - return nil -} - type wgProxyFactory interface { GetProxy() wgproxy.Proxy Free() error @@ -58,10 +45,10 @@ type WGIFaceOpts struct { Address string WGPort int WGPrivKey string - MTU uint16 + MTU int MobileArgs *device.MobileIFaceArguments TransportNet transport.Net - FilterFn udpmux.FilterFn + FilterFn bind.FilterFn DisableDNS bool } @@ -95,10 +82,6 @@ func (w *WGIface) Address() wgaddr.Address { return w.tun.WgAddress() } -func (w *WGIface) MTU() uint16 { - return w.tun.MTU() -} - // ToInterface returns the net.Interface for the Wireguard interface func (r *WGIface) ToInterface() *net.Interface { name := r.tun.DeviceName() @@ -114,7 +97,7 @@ func (r *WGIface) ToInterface() *net.Interface { // Up configures a Wireguard interface // The interface must exist before calling this method (e.g. call interface.Create() before) -func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 26952f48d..c8babea32 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -3,7 +3,6 @@ package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) @@ -15,16 +14,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - - if netstack.IsEnabled() { - wgIFace := &WGIface{ - userspaceBind: true, - tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), - } - return wgIFace, nil - } + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace := &WGIface{ userspaceBind: true, diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go index 7dd74d571..93fd7fd5c 100644 --- a/client/iface/iface_new_darwin.go +++ b/client/iface/iface_new_darwin.go @@ -17,7 +17,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) var tun WGTunDevice if netstack.IsEnabled() { diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go deleted file mode 100644 index 86ed14ce1..000000000 --- a/client/iface/iface_new_freebsd.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build freebsd - -package iface - -import ( - "fmt" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/wgaddr" - "github.com/netbirdio/netbird/client/iface/wgproxy" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := wgaddr.ParseWGAddress(opts.Address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{} - - if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) - return wgIFace, nil - } - - if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) - wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) - wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) - return wgIFace, nil - } - - return nil, fmt.Errorf("couldn't check or load tun module") -} diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 06ccf0be1..317ee0f46 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -16,10 +16,10 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace := &WGIface{ - tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), + tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, wgProxyFactory: wgproxy.NewUSPFactory(iceBind), } diff --git a/client/iface/iface_new_linux.go b/client/iface/iface_new_unix.go similarity index 90% rename from client/iface/iface_new_linux.go rename to client/iface/iface_new_unix.go index 77fd30fae..23ee7236f 100644 --- a/client/iface/iface_new_linux.go +++ b/client/iface/iface_new_unix.go @@ -1,4 +1,4 @@ -//go:build linux && !android +//go:build (linux && !android) || freebsd package iface @@ -22,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{} if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) @@ -31,11 +31,11 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { if device.WireGuardModuleIsLoaded() { wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) - wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort, opts.MTU) + wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort) return wgIFace, nil } if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go index 349c5b33b..413062940 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new_windows.go @@ -14,7 +14,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { if err != nil { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) var tun WGTunDevice if netstack.IsEnabled() { diff --git a/client/iface/udpmux/doc.go b/client/iface/udpmux/doc.go deleted file mode 100644 index 27e5e43bc..000000000 --- a/client/iface/udpmux/doc.go +++ /dev/null @@ -1,64 +0,0 @@ -// Package udpmux provides a custom implementation of a UDP multiplexer -// that allows multiple logical ICE connections to share a single underlying -// UDP socket. This is based on Pion's ICE library, with modifications for -// NetBird's requirements. -// -// # Background -// -// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity -// Establishment) is responsible for discovering candidate network paths -// and maintaining connectivity between peers. Each ICE connection -// normally requires a dedicated UDP socket. However, using one socket -// per candidate can be inefficient and difficult to manage. -// -// This package introduces SingleSocketUDPMux, which allows multiple ICE -// candidate connections (muxed connections) to share a single UDP socket. -// It handles demultiplexing of packets based on ICE ufrag values, STUN -// attributes, and candidate IDs. -// -// # Usage -// -// The typical flow is: -// -// 1. Create a UDP socket (net.PacketConn). -// 2. Construct Params with the socket and optional logger/net stack. -// 3. Call NewSingleSocketUDPMux(params). -// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID) -// to obtain a logical PacketConn. -// 5. Use the returned PacketConn just like a normal UDP connection. -// -// # STUN Message Routing Logic -// -// When a STUN packet arrives, the mux decides which connection should -// receive it using this routing logic: -// -// Primary Routing: Candidate Pair ID -// - Extract the candidate pair ID from the STUN message using -// ice.CandidatePairIDFromSTUN(msg) -// - The target candidate is the locally generated candidate that -// corresponds to the connection that should handle this STUN message -// - If found, use the target candidate ID to lookup the specific -// connection in candidateConnMap -// - Route the message directly to that connection -// -// Fallback Routing: Broadcasting -// When candidate pair ID is not available or lookup fails: -// - Collect connections from addressMap based on source address -// - Find connection using username attribute (ufrag) from STUN message -// - Remove duplicate connections from the list -// - Send the STUN message to all collected connections -// -// # Peer Reflexive Candidate Discovery -// -// When a remote peer sends a STUN message from an unknown source address -// (from a candidate that has not been exchanged via signal), the ICE -// library will: -// - Generate a new peer reflexive candidate for this source address -// - Extract or assign a candidate ID based on the STUN message attributes -// - Create a mapping between the new peer reflexive candidate ID and -// the appropriate local connection -// -// This discovery mechanism ensures that STUN messages from newly discovered -// peer reflexive candidates can be properly routed to the correct local -// connection without requiring fallback broadcasting. -package udpmux diff --git a/client/iface/udpmux/mux_ios.go b/client/iface/udpmux/mux_ios.go deleted file mode 100644 index 4cf211d8f..000000000 --- a/client/iface/udpmux/mux_ios.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build ios - -package udpmux - -func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { - // iOS doesn't support nbnet hooks, so this is a no-op -} diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index dbc694e91..f68e84810 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -12,41 +12,31 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) -type IceBind interface { - SetEndpoint(fakeIP netip.Addr, conn net.Conn) - RemoveEndpoint(fakeIP netip.Addr) - Recv(ctx context.Context, msg bind.RecvMessage) - MTU() uint16 -} - type ProxyBind struct { - bind IceBind + Bind *bind.ICEBind - // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address - wgRelayedEndpoint *bind.Endpoint - wgCurrentUsed *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + fakeNetIP *netip.AddrPort + wgBindEndpoint *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - paused bool - pausedCond *sync.Cond - isStarted bool + pausedMu sync.Mutex + paused bool + isStarted bool closeListener *listener.CloseListener } -func NewProxyBind(bind IceBind) *ProxyBind { +func NewProxyBind(bind *bind.ICEBind) *ProxyBind { p := &ProxyBind{ - bind: bind, + Bind: bind, closeListener: listener.NewCloseListener(), - pausedCond: sync.NewCond(&sync.Mutex{}), } return p @@ -55,25 +45,25 @@ func NewProxyBind(bind IceBind) *ProxyBind { // AddTurnConn adds a new connection to the bind. // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. -// -// Parameters: -// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages -// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address -// - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } - p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} + + p.fakeNetIP = fakeNetIP + p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return nil } - func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) + return &net.UDPAddr{ + IP: p.fakeNetIP.Addr().AsSlice(), + Port: int(p.fakeNetIP.Port()), + Zone: p.fakeNetIP.Addr().Zone(), + } } func (p *ProxyBind) SetDisconnectListener(disconnected func()) { @@ -85,21 +75,17 @@ func (p *ProxyBind) Work() { return } - p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) + p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) - p.pausedCond.L.Lock() + p.pausedMu.Lock() p.paused = false - - p.wgCurrentUsed = p.wgRelayedEndpoint + p.pausedMu.Unlock() // Start the proxy only once if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } - - p.pausedCond.Signal() - p.pausedCond.L.Unlock() } func (p *ProxyBind) Pause() { @@ -107,25 +93,9 @@ func (p *ProxyBind) Pause() { return } - p.pausedCond.L.Lock() + p.pausedMu.Lock() p.paused = true - p.pausedCond.L.Unlock() -} - -func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { - p.pausedCond.L.Lock() - p.paused = false - - p.wgCurrentUsed = addrToEndpoint(endpoint) - - p.pausedCond.Signal() - p.pausedCond.L.Unlock() -} - -func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { - ip, _ := netip.AddrFromSlice(addr.IP.To4()) - addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) - return &bind.Endpoint{AddrPort: addrPort} + p.pausedMu.Unlock() } func (p *ProxyBind) CloseConn() error { @@ -136,10 +106,6 @@ func (p *ProxyBind) CloseConn() error { } func (p *ProxyBind) close() error { - if p.remoteConn == nil { - return nil - } - p.closeMu.Lock() defer p.closeMu.Unlock() @@ -153,12 +119,7 @@ func (p *ProxyBind) close() error { p.cancel() - p.pausedCond.L.Lock() - p.paused = false - p.pausedCond.Signal() - p.pausedCond.L.Unlock() - - p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr()) + p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -174,7 +135,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, 1500) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -185,17 +146,18 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { return } - p.pausedCond.L.Lock() - for p.paused { - p.pausedCond.Wait() + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue } msg := bind.RecvMessage{ - Endpoint: p.wgCurrentUsed, + Endpoint: p.wgBindEndpoint, Buffer: buf[:n], } - p.bind.Recv(ctx, msg) - p.pausedCond.L.Unlock() + p.Bind.RecvChan <- msg + p.pausedMu.Unlock() } } diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 858143091..e21fc35d4 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -6,7 +6,9 @@ import ( "context" "fmt" "net" + "os" "sync" + "syscall" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -15,25 +17,18 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface/bufsize" - "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( loopbackAddr = "127.0.0.1" ) -var ( - localHostNetIP = net.ParseIP("127.0.0.1") -) - // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int - mtu uint16 ebpfManager ebpfMgr.Manager turnConnStore map[uint16]net.Conn @@ -48,11 +43,10 @@ type WGEBPFProxy struct { } // NewWGEBPFProxy create new WGEBPFProxy instance -func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy { +func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { log.Debugf("instantiate ebpf proxy") wgProxy := &WGEBPFProxy{ localWGListenPort: wgPort, - mtu: mtu, ebpfManager: ebpf.GetEbpfManagerInstance(), turnConnStore: make(map[uint16]net.Conn), } @@ -67,7 +61,7 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = rawsocket.PrepareSenderRawSocket() + p.rawConn, err = p.prepareSenderRawSocket() if err != nil { return err } @@ -144,7 +138,7 @@ func (p *WGEBPFProxy) Free() error { // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { - buf := make([]byte, p.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, 1500) for p.ctx.Err() == nil { if err := p.readAndForwardPacket(buf); err != nil { if p.ctx.Err() != nil { @@ -217,17 +211,57 @@ generatePort: return p.lastUsedPort, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { +func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + if err != nil { + return nil, fmt.Errorf("creating raw socket failed: %w", err) + } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) + if err != nil { + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + + // Bind the socket to the "lo" interface. + err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") + if err != nil { + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = nbnet.SetSocketOpt(fd) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) + } + + return packetConn, nil +} + +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { + localhost := net.ParseIP("127.0.0.1") + payload := gopacket.Payload(data) ipH := &layers.IPv4{ - DstIP: localHostNetIP, - SrcIP: endpointAddr.IP, + DstIP: localhost, + SrcIP: localhost, Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, } udpH := &layers.UDP{ - SrcPort: layers.UDPPort(endpointAddr.Port), + SrcPort: layers.UDPPort(port), DstPort: layers.UDPPort(p.localWGListenPort), } @@ -242,7 +276,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { if err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { + if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/proxy_test.go b/client/iface/wgproxy/ebpf/proxy_test.go index 3ec4f0eba..b15bc686c 100644 --- a/client/iface/wgproxy/ebpf/proxy_test.go +++ b/client/iface/wgproxy/ebpf/proxy_test.go @@ -7,7 +7,7 @@ import ( ) func TestWGEBPFProxy_connStore(t *testing.T) { - wgProxy := NewWGEBPFProxy(1, 1280) + wgProxy := NewWGEBPFProxy(1) p, _ := wgProxy.storeTurnConn(nil) if p != 1 { @@ -27,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) { } func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { - wgProxy := NewWGEBPFProxy(1, 1280) + wgProxy := NewWGEBPFProxy(1) _, _ = wgProxy.storeTurnConn(nil) wgProxy.lastUsedPort = 65535 @@ -43,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { } func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { - wgProxy := NewWGEBPFProxy(1, 1280) + wgProxy := NewWGEBPFProxy(1) for i := 0; i < 65535; i++ { _, _ = wgProxy.storeTurnConn(nil) diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index ff44d30c0..b25dc4198 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -12,48 +12,46 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { - wgeBPFProxy *WGEBPFProxy + WgeBPFProxy *WGEBPFProxy remoteConn net.Conn ctx context.Context cancel context.CancelFunc - wgRelayedEndpointAddr *net.UDPAddr - wgEndpointCurrentUsedAddr *net.UDPAddr + wgEndpointAddr *net.UDPAddr - paused bool - pausedCond *sync.Cond - isStarted bool + pausedMu sync.Mutex + paused bool + isStarted bool closeListener *listener.CloseListener } -func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { +func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { return &ProxyWrapper{ - wgeBPFProxy: proxy, - pausedCond: sync.NewCond(&sync.Mutex{}), + WgeBPFProxy: WgeBPFProxy, closeListener: listener.NewCloseListener(), } } + func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { - addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) - p.wgRelayedEndpointAddr = addr + p.wgEndpointAddr = addr return err } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { - return p.wgRelayedEndpointAddr + return p.wgEndpointAddr } func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { @@ -65,18 +63,14 @@ func (p *ProxyWrapper) Work() { return } - p.pausedCond.L.Lock() + p.pausedMu.Lock() p.paused = false - - p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr + p.pausedMu.Unlock() if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } - - p.pausedCond.Signal() - p.pausedCond.L.Unlock() } func (p *ProxyWrapper) Pause() { @@ -85,59 +79,45 @@ func (p *ProxyWrapper) Pause() { } log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) - p.pausedCond.L.Lock() + p.pausedMu.Lock() p.paused = true - p.pausedCond.L.Unlock() -} - -func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { - p.pausedCond.L.Lock() - p.paused = false - - p.wgEndpointCurrentUsedAddr = endpoint - - p.pausedCond.Signal() - p.pausedCond.L.Unlock() + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map -func (p *ProxyWrapper) CloseConn() error { - if p.cancel == nil { +func (e *ProxyWrapper) CloseConn() error { + if e.cancel == nil { return fmt.Errorf("proxy not started") } - p.cancel() + e.cancel() - p.closeListener.SetCloseListener(nil) + e.closeListener.SetCloseListener(nil) - p.pausedCond.L.Lock() - p.paused = false - p.pausedCond.Signal() - p.pausedCond.L.Unlock() - - if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - return fmt.Errorf("failed to close remote conn: %w", err) + if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("close remote conn: %w", err) } return nil } func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { - defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port)) + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) - buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, 1500) for { n, err := p.readFromRemote(ctx, buf) if err != nil { return } - p.pausedCond.L.Lock() - for p.paused { - p.pausedCond.Wait() + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue } - err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) - p.pausedCond.L.Unlock() + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() if err != nil { if ctx.Err() != nil { @@ -156,7 +136,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } p.closeListener.Notify() if !errors.Is(err, io.EOF) { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err) + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) } return 0, err } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index ad2807546..e62cd97be 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -11,18 +11,16 @@ import ( type KernelFactory struct { wgPort int - mtu uint16 ebpfProxy *ebpf.WGEBPFProxy } -func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { +func NewKernelFactory(wgPort int) *KernelFactory { f := &KernelFactory{ wgPort: wgPort, - mtu: mtu, } - ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu) + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) if err := ebpfProxy.Listen(); err != nil { log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) @@ -35,10 +33,11 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { func (w *KernelFactory) GetProxy() Proxy { if w.ebpfProxy == nil { - return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu) + return udpProxy.NewWGUDPProxy(w.wgPort) } return ebpf.NewProxyWrapper(w.ebpfProxy) + } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go new file mode 100644 index 000000000..736944229 --- /dev/null +++ b/client/iface/wgproxy/factory_kernel_freebsd.go @@ -0,0 +1,29 @@ +package wgproxy + +import ( + log "github.com/sirupsen/logrus" + + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +// KernelFactory todo: check eBPF support on FreeBSD +type KernelFactory struct { + wgPort int +} + +func NewKernelFactory(wgPort int) *KernelFactory { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") + f := &KernelFactory{ + wgPort: wgPort, + } + + return f +} + +func (w *KernelFactory) GetProxy() Proxy { + return udpProxy.NewWGUDPProxy(w.wgPort) +} + +func (w *KernelFactory) Free() error { + return nil +} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index 3c8dfd30e..c2879877e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -11,11 +11,6 @@ type Proxy interface { EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. - - //RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused - //and rewrite the src address to the endpoint address. - //With this logic can avoid the package loss from relayed connections. - RedirectAs(endpoint *net.UDPAddr) CloseConn() error SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 9526e91d2..298c98cc0 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -3,82 +3,54 @@ package wgproxy import ( - "fmt" - "net" + "context" + "os" + "testing" - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/wgaddr" - bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" - "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) -func seedProxies() ([]proxyInstance, error) { - pl := make([]proxyInstance, 0) +func TestProxyCloseByRemoteConnEBPF(t *testing.T) { + if os.Getenv("GITHUB_ACTIONS") != "true" { + t.Skip("Skipping test as it requires root privileges") + } + ctx := context.Background() - ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) + ebpfProxy := ebpf.NewWGEBPFProxy(51831) if err := ebpfProxy.Listen(); err != nil { - return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) + t.Fatalf("failed to initialize ebpf proxy: %s", err) } - pEbpf := proxyInstance{ - name: "ebpf kernel proxy", - proxy: ebpf.NewProxyWrapper(ebpfProxy), - wgPort: 51831, - closeFn: ebpfProxy.Free, - } - pl = append(pl, pEbpf) + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %s", err) + } + }() - pUDP := proxyInstance{ - name: "udp kernel proxy", - proxy: udp.NewWGUDPProxy(51832, 1280), - wgPort: 51832, - closeFn: func() error { return nil }, + tests := []struct { + name string + proxy Proxy + }{ + { + name: "ebpf proxy", + proxy: &ebpf.ProxyWrapper{ + WgeBPFProxy: ebpfProxy, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + relayedConn := newMockConn() + err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + if err != nil { + t.Errorf("error: %v", err) + } + + _ = relayedConn.Close() + if err := tt.proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }) } - pl = append(pl, pUDP) - return pl, nil -} - -func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { - pl := make([]proxyInstance, 0) - - ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) - if err := ebpfProxy.Listen(); err != nil { - return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) - } - - pEbpf := proxyInstance{ - name: "ebpf kernel proxy", - proxy: ebpf.NewProxyWrapper(ebpfProxy), - wgPort: 51831, - closeFn: ebpfProxy.Free, - } - pl = append(pl, pEbpf) - - pUDP := proxyInstance{ - name: "udp kernel proxy", - proxy: udp.NewWGUDPProxy(51832, 1280), - wgPort: 51832, - closeFn: func() error { return nil }, - } - pl = append(pl, pUDP) - wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") - if err != nil { - return nil, err - } - iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) - endpointAddress := &net.UDPAddr{ - IP: net.IPv4(10, 0, 0, 1), - Port: 1234, - } - - pBind := proxyInstance{ - name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), - endpointAddr: endpointAddress, - closeFn: func() error { return nil }, - } - pl = append(pl, pBind) - - return pl, nil } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go deleted file mode 100644 index 4d244f18a..000000000 --- a/client/iface/wgproxy/proxy_seed_test.go +++ /dev/null @@ -1,39 +0,0 @@ -//go:build !linux - -package wgproxy - -import ( - "net" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/wgaddr" - bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" -) - -func seedProxies() ([]proxyInstance, error) { - // todo extend with Bind proxy - pl := make([]proxyInstance, 0) - return pl, nil -} - -func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { - pl := make([]proxyInstance, 0) - wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") - if err != nil { - return nil, err - } - iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) - endpointAddress := &net.UDPAddr{ - IP: net.IPv4(10, 0, 0, 1), - Port: 1234, - } - - pBind := proxyInstance{ - name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), - endpointAddr: endpointAddress, - closeFn: func() error { return nil }, - } - pl = append(pl, pBind) - return pl, nil -} diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 1aeab66b7..6882f9ea2 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -1,3 +1,5 @@ +//go:build linux + package wgproxy import ( @@ -5,9 +7,12 @@ import ( "io" "net" "os" + "runtime" "testing" "time" + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -17,14 +22,6 @@ func TestMain(m *testing.M) { os.Exit(code) } -type proxyInstance struct { - name string - proxy Proxy - wgPort int - endpointAddr *net.UDPAddr - closeFn func() error -} - type mocConn struct { closeChan chan struct{} closed bool @@ -81,21 +78,41 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error { func TestProxyCloseByRemoteConn(t *testing.T) { ctx := context.Background() - tests, err := seedProxyForProxyCloseByRemoteConn() - if err != nil { - t.Fatalf("error: %v", err) + tests := []struct { + name string + proxy Proxy + }{ + { + name: "userspace proxy", + proxy: udpProxy.NewWGUDPProxy(51830), + }, } - relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") - defer func() { - _ = relayedConn.Close() - }() + if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { + ebpfProxy := ebpf.NewWGEBPFProxy(51831) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %s", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %s", err) + } + }() + proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) + + tests = append(tests, struct { + name string + proxy Proxy + }{ + name: "ebpf proxy", + proxy: proxyWrapper, + }) + } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892") relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, addr, relayedConn) + err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) if err != nil { t.Errorf("error: %v", err) } @@ -107,104 +124,3 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }) } } - -// TestProxyRedirect todo extend the proxies with Bind proxy -func TestProxyRedirect(t *testing.T) { - tests, err := seedProxies() - if err != nil { - t.Fatalf("error: %v", err) - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr) - if err := tt.closeFn(); err != nil { - t.Errorf("error: %v", err) - } - }) - } -} - -func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) { - t.Helper() - - msgHelloFromRelay := []byte("hello from relay") - msgRedirected := [][]byte{ - []byte("hello 1. to p2p"), - []byte("hello 2. to p2p"), - []byte("hello 3. to p2p"), - } - - dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: wgPort}) - if err != nil { - t.Fatalf("failed to listen on udp port: %s", err) - } - - relayedServer, _ := net.ListenUDP("udp", - &net.UDPAddr{ - IP: net.IPv4(127, 0, 0, 1), - Port: 1234, - }, - ) - - relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") - - defer func() { - _ = dummyWgListener.Close() - _ = relayedConn.Close() - _ = relayedServer.Close() - }() - - if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil { - t.Errorf("error: %v", err) - } - defer func() { - if err := proxy.CloseConn(); err != nil { - t.Errorf("error: %v", err) - } - }() - - proxy.Work() - - if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil { - t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err) - } - - n, err := dummyWgListener.Read(make([]byte, 1024)) - if err != nil { - t.Errorf("error: %v", err) - } - - if n != len(msgHelloFromRelay) { - t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n) - } - - p2pEndpointAddr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 0, 56), - Port: 1234, - } - proxy.RedirectAs(p2pEndpointAddr) - - for _, msg := range msgRedirected { - if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil { - t.Errorf("error: %v", err) - } - } - - for i := 0; i < len(msgRedirected); i++ { - buf := make([]byte, 1024) - n, rAddr, err := dummyWgListener.ReadFrom(buf) - if err != nil { - t.Errorf("error: %v", err) - } - - if rAddr.String() != p2pEndpointAddr.String() { - t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String()) - } - if string(buf[:n]) != string(msgRedirected[i]) { - t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n])) - } - } -} diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go deleted file mode 100644 index a11ac46d5..000000000 --- a/client/iface/wgproxy/rawsocket/rawsocket.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build linux && !android - -package rawsocket - -import ( - "fmt" - "net" - "os" - "syscall" - - nbnet "github.com/netbirdio/netbird/client/net" -) - -func PrepareSenderRawSocket() (net.PacketConn, error) { - // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) - if err != nil { - return nil, fmt.Errorf("creating raw socket failed: %w", err) - } - - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) - } - - // Bind the socket to the "lo" interface. - err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") - if err != nil { - return nil, fmt.Errorf("binding to lo interface failed: %w", err) - } - - // Set the fwmark on the socket. - err = nbnet.SetSocketOpt(fd) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - - // Convert the file descriptor to a PacketConn. - file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) - if file == nil { - return nil, fmt.Errorf("converting fd to file failed") - } - packetConn, err := net.FilePacketConn(file) - if err != nil { - return nil, fmt.Errorf("converting file to packet conn failed: %w", err) - } - - return packetConn, nil -} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 4ef2f19c4..139ccd4ed 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,5 +1,3 @@ -//go:build linux && !android - package udp import ( @@ -14,38 +12,32 @@ import ( log "github.com/sirupsen/logrus" cerrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) // WGUDPProxy proxies type WGUDPProxy struct { localWGListenPort int - mtu uint16 - remoteConn net.Conn - localConn net.Conn - srcFakerConn *SrcFaker - sendPkg func(data []byte) (int, error) - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + remoteConn net.Conn + localConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - paused bool - pausedCond *sync.Cond - isStarted bool + pausedMu sync.Mutex + paused bool + isStarted bool closeListener *listener.CloseListener } // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation -func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { +func NewWGUDPProxy(wgPort int) *WGUDPProxy { log.Debugf("Initializing new user space proxy with port %d", wgPort) p := &WGUDPProxy{ localWGListenPort: wgPort, - mtu: mtu, - pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } return p @@ -66,7 +58,6 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem p.ctx, p.cancel = context.WithCancel(ctx) p.localConn = localConn - p.sendPkg = p.localConn.Write p.remoteConn = remoteConn return err @@ -90,24 +81,15 @@ func (p *WGUDPProxy) Work() { return } - p.pausedCond.L.Lock() + p.pausedMu.Lock() p.paused = false - p.sendPkg = p.localConn.Write - - if p.srcFakerConn != nil { - if err := p.srcFakerConn.Close(); err != nil { - log.Errorf("failed to close src faker conn: %s", err) - } - p.srcFakerConn = nil - } + p.pausedMu.Unlock() if !p.isStarted { p.isStarted = true go p.proxyToRemote(p.ctx) go p.proxyToLocal(p.ctx) } - p.pausedCond.Signal() - p.pausedCond.L.Unlock() } // Pause pauses the proxy from receiving data from the remote peer @@ -116,35 +98,9 @@ func (p *WGUDPProxy) Pause() { return } - p.pausedCond.L.Lock() + p.pausedMu.Lock() p.paused = true - p.pausedCond.L.Unlock() -} - -// RedirectAs start to use the fake sourced raw socket as package sender -func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) { - p.pausedCond.L.Lock() - defer func() { - p.pausedCond.Signal() - p.pausedCond.L.Unlock() - }() - - p.paused = false - if p.srcFakerConn != nil { - if err := p.srcFakerConn.Close(); err != nil { - log.Errorf("failed to close src faker conn: %s", err) - } - p.srcFakerConn = nil - } - srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint) - if err != nil { - log.Errorf("failed to create src faker conn: %s", err) - // fallback to continue without redirecting - p.paused = true - return - } - p.srcFakerConn = srcFakerConn - p.sendPkg = p.srcFakerConn.SendPkg + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -156,8 +112,6 @@ func (p *WGUDPProxy) CloseConn() error { } func (p *WGUDPProxy) close() error { - var result *multierror.Error - p.closeMu.Lock() defer p.closeMu.Unlock() @@ -171,11 +125,7 @@ func (p *WGUDPProxy) close() error { p.cancel() - p.pausedCond.L.Lock() - p.paused = false - p.pausedCond.Signal() - p.pausedCond.L.Unlock() - + var result *multierror.Error if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } @@ -183,13 +133,6 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } - - if p.srcFakerConn != nil { - if err := p.srcFakerConn.Close(); err != nil { - result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err)) - } - } - return cerrors.FormatErrorOrNil(result) } @@ -201,7 +144,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { } }() - buf := make([]byte, p.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, 1500) for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { @@ -236,7 +179,7 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { } }() - buf := make([]byte, p.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, 1500) for { n, err := p.remoteConnRead(ctx, buf) if err != nil { @@ -248,12 +191,14 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { return } - p.pausedCond.L.Lock() - for p.paused { - p.pausedCond.Wait() + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue } - _, err = p.sendPkg(buf[:n]) - p.pausedCond.L.Unlock() + + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() if err != nil { if ctx.Err() != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go deleted file mode 100644 index fdc911463..000000000 --- a/client/iface/wgproxy/udp/rawsocket.go +++ /dev/null @@ -1,101 +0,0 @@ -//go:build linux && !android - -package udp - -import ( - "fmt" - "net" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" -) - -var ( - serializeOpts = gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } - - localHostNetIPAddr = &net.IPAddr{ - IP: net.ParseIP("127.0.0.1"), - } -) - -type SrcFaker struct { - srcAddr *net.UDPAddr - - rawSocket net.PacketConn - ipH gopacket.SerializableLayer - udpH gopacket.SerializableLayer - layerBuffer gopacket.SerializeBuffer -} - -func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { - rawSocket, err := rawsocket.PrepareSenderRawSocket() - if err != nil { - return nil, err - } - - ipH, udpH, err := prepareHeaders(dstPort, srcAddr) - if err != nil { - return nil, err - } - - f := &SrcFaker{ - srcAddr: srcAddr, - rawSocket: rawSocket, - ipH: ipH, - udpH: udpH, - layerBuffer: gopacket.NewSerializeBuffer(), - } - - return f, nil -} - -func (f *SrcFaker) Close() error { - return f.rawSocket.Close() -} - -func (f *SrcFaker) SendPkg(data []byte) (int, error) { - defer func() { - if err := f.layerBuffer.Clear(); err != nil { - log.Errorf("failed to clear layer buffer: %s", err) - } - }() - - payload := gopacket.Payload(data) - - err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload) - if err != nil { - return 0, fmt.Errorf("serialize layers: %w", err) - } - n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) - if err != nil { - return 0, fmt.Errorf("write to raw conn: %w", err) - } - return n, nil -} - -func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { - ipH := &layers.IPv4{ - DstIP: net.ParseIP("127.0.0.1"), - SrcIP: srcAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - } - udpH := &layers.UDP{ - SrcPort: layers.UDPPort(srcAddr.Port), - DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port - } - - err := udpH.SetNetworkLayerForChecksum(ipH) - if err != nil { - return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) - } - - return ipH, udpH, nil -} diff --git a/client/internal/auth/device_flow_test.go b/client/internal/auth/device_flow_test.go index 466645ee9..dc950ac63 100644 --- a/client/internal/auth/device_flow_test.go +++ b/client/internal/auth/device_flow_test.go @@ -3,17 +3,15 @@ package auth import ( "context" "fmt" + "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/client/internal" + "github.com/stretchr/testify/require" "io" "net/http" "net/url" "strings" "testing" "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal" ) type mockHTTPClient struct { diff --git a/client/internal/connect.go b/client/internal/connect.go index 295d35a43..b62a2d951 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -18,7 +18,6 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" @@ -34,7 +33,7 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -247,15 +246,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.statusRecorder.MarkSignalConnected() relayURLs, token := parseRelayInfo(loginResp) - peerConfig := loginResp.GetPeerConfig() - - engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) - if err != nil { - log.Error(err) - return wrapErr(err) - } - - relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) + relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) c.statusRecorder.SetRelayMgr(relayManager) if len(relayURLs) > 0 { if token != nil { @@ -271,6 +262,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } peerConfig := loginResp.GetPeerConfig() + engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, c.LogFile) if err != nil { log.Error(err) @@ -284,7 +276,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engineMutex.Unlock() - if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { + if err := c.engine.Start(); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } @@ -293,8 +285,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan state.Set(StatusConnected) if runningChan != nil { - close(runningChan) - runningChan = nil + select { + case runningChan <- struct{}{}: + default: + } } <-engineCtx.Done() @@ -453,8 +447,8 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf LazyConnectionEnabled: config.LazyConnectionEnabled, LogFile: logFile, + ProfileConfig: config, - MTU: selectMTU(config.MTU, peerConfig.Mtu), } if config.PreSharedKey != "" { @@ -477,20 +471,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf return engineConf, nil } -func selectMTU(localMTU uint16, peerMTU int32) uint16 { - var finalMTU uint16 = iface.DefaultMTU - if localMTU > 0 { - finalMTU = localMTU - } else if peerMTU > 0 { - finalMTU = uint16(peerMTU) - } - - // Set global DNS MTU - dns.SetCurrentMTU(finalMTU) - - return finalMTU -} - // connectToSignal creates Signal Service client and established a connection func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) { var sigTLSEnabled bool diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go deleted file mode 100644 index cb651f1e5..000000000 --- a/client/internal/dns/config/domains.go +++ /dev/null @@ -1,201 +0,0 @@ -package config - -import ( - "errors" - "fmt" - "net" - "net/netip" - "net/url" - "strings" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/shared/management/domain" - mgmProto "github.com/netbirdio/netbird/shared/management/proto" -) - -var ( - ErrEmptyURL = errors.New("empty URL") - ErrEmptyHost = errors.New("empty host") - ErrIPNotAllowed = errors.New("IP address not allowed") -) - -// ServerDomains represents the management server domains extracted from NetBird configuration -type ServerDomains struct { - Signal domain.Domain - Relay []domain.Domain - Flow domain.Domain - Stuns []domain.Domain - Turns []domain.Domain -} - -// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration -func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { - if config == nil { - return ServerDomains{} - } - - domains := ServerDomains{} - - domains.Signal = extractSignalDomain(config) - domains.Relay = extractRelayDomains(config) - domains.Flow = extractFlowDomain(config) - domains.Stuns = extractStunDomains(config) - domains.Turns = extractTurnDomains(config) - - return domains -} - -// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses -func ExtractValidDomain(rawURL string) (domain.Domain, error) { - if rawURL == "" { - return "", ErrEmptyURL - } - - parsedURL, err := url.Parse(rawURL) - if err == nil { - if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" { - return domain, err - } - } - - return extractFromRawString(rawURL) -} - -// extractFromParsedURL handles domain extraction from successfully parsed URLs -func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) { - if parsedURL.Hostname() != "" { - return extractDomainFromHost(parsedURL.Hostname()) - } - - if parsedURL.Opaque == "" || parsedURL.Scheme == "" { - return "", nil - } - - // Handle URLs with opaque content (e.g., stun:host:port) - if strings.Contains(parsedURL.Scheme, ".") { - // This is likely "domain.com:port" being parsed as scheme:opaque - reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque - if host, _, err := net.SplitHostPort(reconstructed); err == nil { - return extractDomainFromHost(host) - } - return extractDomainFromHost(parsedURL.Scheme) - } - - // Valid scheme with opaque content (e.g., stun:host:port) - host := parsedURL.Opaque - if queryIndex := strings.Index(host, "?"); queryIndex > 0 { - host = host[:queryIndex] - } - - if hostOnly, _, err := net.SplitHostPort(host); err == nil { - return extractDomainFromHost(hostOnly) - } - - return extractDomainFromHost(host) -} - -// extractFromRawString handles domain extraction when URL parsing fails or returns no results -func extractFromRawString(rawURL string) (domain.Domain, error) { - if host, _, err := net.SplitHostPort(rawURL); err == nil { - return extractDomainFromHost(host) - } - - return extractDomainFromHost(rawURL) -} - -// extractDomainFromHost extracts domain from a host string, filtering out IP addresses -func extractDomainFromHost(host string) (domain.Domain, error) { - if host == "" { - return "", ErrEmptyHost - } - - if _, err := netip.ParseAddr(host); err == nil { - return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host) - } - - d, err := domain.FromString(host) - if err != nil { - return "", fmt.Errorf("invalid domain: %v", err) - } - - return d, nil -} - -// extractSingleDomain extracts a single domain from a URL with error logging -func extractSingleDomain(url, serviceType string) domain.Domain { - if url == "" { - return "" - } - - d, err := ExtractValidDomain(url) - if err != nil { - log.Debugf("Skipping %s: %v", serviceType, err) - return "" - } - - return d -} - -// extractMultipleDomains extracts multiple domains from URLs with error logging -func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { - var domains []domain.Domain - for _, url := range urls { - if url == "" { - continue - } - d, err := ExtractValidDomain(url) - if err != nil { - log.Debugf("Skipping %s: %v", serviceType, err) - continue - } - domains = append(domains, d) - } - return domains -} - -// extractSignalDomain extracts the signal domain from NetBird configuration. -func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain { - if config.Signal != nil { - return extractSingleDomain(config.Signal.Uri, "signal") - } - return "" -} - -// extractRelayDomains extracts relay server domains from NetBird configuration. -func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { - if config.Relay != nil { - return extractMultipleDomains(config.Relay.Urls, "relay") - } - return nil -} - -// extractFlowDomain extracts the traffic flow domain from NetBird configuration. -func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain { - if config.Flow != nil { - return extractSingleDomain(config.Flow.Url, "flow") - } - return "" -} - -// extractStunDomains extracts STUN server domains from NetBird configuration. -func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain { - var urls []string - for _, stun := range config.Stuns { - if stun != nil && stun.Uri != "" { - urls = append(urls, stun.Uri) - } - } - return extractMultipleDomains(urls, "STUN") -} - -// extractTurnDomains extracts TURN server domains from NetBird configuration. -func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain { - var urls []string - for _, turn := range config.Turns { - if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { - urls = append(urls, turn.HostConfig.Uri) - } - } - return extractMultipleDomains(urls, "TURN") -} diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go deleted file mode 100644 index 5eae3a541..000000000 --- a/client/internal/dns/config/domains_test.go +++ /dev/null @@ -1,213 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestExtractValidDomain(t *testing.T) { - tests := []struct { - name string - url string - expected string - expectError bool - }{ - { - name: "HTTPS URL with port", - url: "https://api.netbird.io:443", - expected: "api.netbird.io", - }, - { - name: "HTTP URL without port", - url: "http://signal.example.com", - expected: "signal.example.com", - }, - { - name: "Host with port (no scheme)", - url: "signal.netbird.io:443", - expected: "signal.netbird.io", - }, - { - name: "STUN URL", - url: "stun:stun.netbird.io:443", - expected: "stun.netbird.io", - }, - { - name: "STUN URL with different port", - url: "stun:stun.netbird.io:5555", - expected: "stun.netbird.io", - }, - { - name: "TURNS URL with query params", - url: "turns:turn.netbird.io:443?transport=tcp", - expected: "turn.netbird.io", - }, - { - name: "TURN URL", - url: "turn:turn.example.com:3478", - expected: "turn.example.com", - }, - { - name: "REL URL", - url: "rel://relay.example.com:443", - expected: "relay.example.com", - }, - { - name: "RELS URL", - url: "rels://relay.netbird.io:443", - expected: "relay.netbird.io", - }, - { - name: "Raw hostname", - url: "example.org", - expected: "example.org", - }, - { - name: "IP address should be rejected", - url: "192.168.1.1", - expectError: true, - }, - { - name: "IP address with port should be rejected", - url: "192.168.1.1:443", - expectError: true, - }, - { - name: "IPv6 address should be rejected", - url: "2001:db8::1", - expectError: true, - }, - { - name: "HTTP URL with IPv4 should be rejected", - url: "http://192.168.1.1:8080", - expectError: true, - }, - { - name: "HTTPS URL with IPv4 should be rejected", - url: "https://10.0.0.1:443", - expectError: true, - }, - { - name: "STUN URL with IPv4 should be rejected", - url: "stun:192.168.1.1:3478", - expectError: true, - }, - { - name: "TURN URL with IPv4 should be rejected", - url: "turn:10.0.0.1:3478", - expectError: true, - }, - { - name: "TURNS URL with IPv4 should be rejected", - url: "turns:172.16.0.1:5349", - expectError: true, - }, - { - name: "HTTP URL with IPv6 should be rejected", - url: "http://[2001:db8::1]:8080", - expectError: true, - }, - { - name: "HTTPS URL with IPv6 should be rejected", - url: "https://[::1]:443", - expectError: true, - }, - { - name: "STUN URL with IPv6 should be rejected", - url: "stun:[2001:db8::1]:3478", - expectError: true, - }, - { - name: "IPv6 with port should be rejected", - url: "[2001:db8::1]:443", - expectError: true, - }, - { - name: "Localhost IPv4 should be rejected", - url: "127.0.0.1:8080", - expectError: true, - }, - { - name: "Localhost IPv6 should be rejected", - url: "[::1]:443", - expectError: true, - }, - { - name: "REL URL with IPv4 should be rejected", - url: "rel://192.168.1.1:443", - expectError: true, - }, - { - name: "RELS URL with IPv4 should be rejected", - url: "rels://10.0.0.1:443", - expectError: true, - }, - { - name: "Empty URL", - url: "", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := ExtractValidDomain(tt.url) - - if tt.expectError { - assert.Error(t, err, "Expected error for URL: %s", tt.url) - } else { - assert.NoError(t, err, "Unexpected error for URL: %s", tt.url) - assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url) - } - }) - } -} - -func TestExtractDomainFromHost(t *testing.T) { - tests := []struct { - name string - host string - expected string - expectError bool - }{ - { - name: "Valid domain", - host: "example.com", - expected: "example.com", - }, - { - name: "Subdomain", - host: "api.example.com", - expected: "api.example.com", - }, - { - name: "IPv4 address", - host: "192.168.1.1", - expectError: true, - }, - { - name: "IPv6 address", - host: "2001:db8::1", - expectError: true, - }, - { - name: "Empty host", - host: "", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := extractDomainFromHost(tt.host) - - if tt.expectError { - assert.Error(t, err, "Expected error for host: %s", tt.host) - } else { - assert.NoError(t, err, "Unexpected error for host: %s", tt.host) - assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host) - } - }) - } -} diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 2e54bffd9..439bcbb3c 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,12 +11,11 @@ import ( ) const ( - PriorityMgmtCache = 150 - PriorityLocal = 100 - PriorityDNSRoute = 75 - PriorityUpstream = 50 - PriorityDefault = 1 - PriorityFallback = -100 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 + PriorityFallback = -100 ) type SubdomainMatcher interface { @@ -183,10 +182,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants to continue, try next handler if chainWriter.shouldContinue { - // Only log continue for non-management cache handlers to reduce noise - if entry.Priority != PriorityMgmtCache { - log.Tracef("handler requested continue to next handler for domain=%s", qname) - } + log.Tracef("handler requested continue to next handler for domain=%s", qname) continue } return diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..852dfef48 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -166,10 +166,9 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { - if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") - return fmt.Errorf("recordSystemDNSSettings(): %w", err) - } + err := s.recordSystemDNSSettings(true) + log.Errorf("Unable to get system DNS configuration") + return err } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 0d3f033fb..fdc2c3063 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,17 +240,15 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { + policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) + if r.gpo { + policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) + } singleDomain := []string{domain} - if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) - } - - if r.gpo { - if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure gpo DNS policy: %w", err) - } + if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) } log.Debugf("added NRPT entry for domain: %s", domain) @@ -403,7 +401,6 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) } - if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) } @@ -415,7 +412,6 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) } - if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) } diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index bac7875ec..b776fbbe3 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool { // String returns a string representation of the local resolver func (d *Resolver) String() string { - return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) + return fmt.Sprintf("local resolver [%d records]", len(d.records)) } func (d *Resolver) Stop() {} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go deleted file mode 100644 index 290395473..000000000 --- a/client/internal/dns/mgmt/mgmt.go +++ /dev/null @@ -1,360 +0,0 @@ -package mgmt - -import ( - "context" - "fmt" - "net" - "net/url" - "strings" - "sync" - "time" - - "github.com/miekg/dns" - log "github.com/sirupsen/logrus" - - dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" - "github.com/netbirdio/netbird/shared/management/domain" -) - -const dnsTimeout = 5 * time.Second - -// Resolver caches critical NetBird infrastructure domains -type Resolver struct { - records map[dns.Question][]dns.RR - mgmtDomain *domain.Domain - serverDomains *dnsconfig.ServerDomains - mutex sync.RWMutex -} - -// NewResolver creates a new management domains cache resolver. -func NewResolver() *Resolver { - return &Resolver{ - records: make(map[dns.Question][]dns.RR), - } -} - -// String returns a string representation of the resolver. -func (m *Resolver) String() string { - return "MgmtCacheResolver" -} - -// ServeDNS implements dns.Handler interface. -func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) == 0 { - m.continueToNext(w, r) - return - } - - question := r.Question[0] - question.Name = strings.ToLower(dns.Fqdn(question.Name)) - - if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { - m.continueToNext(w, r) - return - } - - m.mutex.RLock() - records, found := m.records[question] - m.mutex.RUnlock() - - if !found { - m.continueToNext(w, r) - return - } - - resp := &dns.Msg{} - resp.SetReply(r) - resp.Authoritative = false - resp.RecursionAvailable = true - - resp.Answer = append(resp.Answer, records...) - - log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) - - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write response: %v", err) - } -} - -// MatchSubdomains returns false since this resolver only handles exact domain matches -// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. -func (m *Resolver) MatchSubdomains() bool { - return false -} - -// continueToNext signals the handler chain to continue to the next handler. -func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { - resp := &dns.Msg{} - resp.SetRcode(r, dns.RcodeNameError) - resp.MsgHdr.Zero = true - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write continue signal: %v", err) - } -} - -// AddDomain manually adds a domain to cache by resolving it. -func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { - dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) - - ctx, cancel := context.WithTimeout(ctx, dnsTimeout) - defer cancel() - - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) - if err != nil { - return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) - } - - var aRecords, aaaaRecords []dns.RR - for _, ip := range ips { - if ip.Is4() { - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: ip.AsSlice(), - } - aRecords = append(aRecords, rr) - } else if ip.Is6() { - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: ip.AsSlice(), - } - aaaaRecords = append(aaaaRecords, rr) - } - } - - m.mutex.Lock() - - if len(aRecords) > 0 { - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - m.records[aQuestion] = aRecords - } - - if len(aaaaRecords) > 0 { - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - m.records[aaaaQuestion] = aaaaRecords - } - - m.mutex.Unlock() - - log.Debugf("added domain=%s with %d A records and %d AAAA records", - d.SafeString(), len(aRecords), len(aaaaRecords)) - - return nil -} - -// PopulateFromConfig extracts and caches domains from the client configuration. -func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { - if mgmtURL == nil { - return nil - } - - d, err := dnsconfig.ExtractValidDomain(mgmtURL.String()) - if err != nil { - return fmt.Errorf("extract domain from URL: %w", err) - } - - m.mutex.Lock() - m.mgmtDomain = &d - m.mutex.Unlock() - - if err := m.AddDomain(ctx, d); err != nil { - return fmt.Errorf("add domain: %w", err) - } - - return nil -} - -// RemoveDomain removes a domain from the cache. -func (m *Resolver) RemoveDomain(d domain.Domain) error { - dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) - - m.mutex.Lock() - defer m.mutex.Unlock() - - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - delete(m.records, aQuestion) - - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - delete(m.records, aaaaQuestion) - - log.Debugf("removed domain=%s from cache", d.SafeString()) - return nil -} - -// GetCachedDomains returns a list of all cached domains. -func (m *Resolver) GetCachedDomains() domain.List { - m.mutex.RLock() - defer m.mutex.RUnlock() - - domainSet := make(map[domain.Domain]struct{}) - for question := range m.records { - domainName := strings.TrimSuffix(question.Name, ".") - domainSet[domain.Domain(domainName)] = struct{}{} - } - - domains := make(domain.List, 0, len(domainSet)) - for d := range domainSet { - domains = append(domains, d) - } - - return domains -} - -// UpdateFromServerDomains updates the cache with server domains from network configuration. -// It merges new domains with existing ones, replacing entire domain types when updated. -// Empty updates are ignored to prevent clearing infrastructure domains during partial updates. -func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) { - newDomains := m.extractDomainsFromServerDomains(serverDomains) - var removedDomains domain.List - - if len(newDomains) > 0 { - m.mutex.Lock() - if m.serverDomains == nil { - m.serverDomains = &dnsconfig.ServerDomains{} - } - updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains) - m.serverDomains = &updatedServerDomains - m.mutex.Unlock() - - allDomains := m.extractDomainsFromServerDomains(updatedServerDomains) - currentDomains := m.GetCachedDomains() - removedDomains = m.removeStaleDomains(currentDomains, allDomains) - } - - m.addNewDomains(ctx, newDomains) - - return removedDomains, nil -} - -// removeStaleDomains removes cached domains not present in the target domain list. -// Management domains are preserved and never removed during server domain updates. -func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List { - var removedDomains domain.List - - for _, currentDomain := range currentDomains { - if m.isDomainInList(currentDomain, newDomains) { - continue - } - - if m.isManagementDomain(currentDomain) { - continue - } - - removedDomains = append(removedDomains, currentDomain) - if err := m.RemoveDomain(currentDomain); err != nil { - log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) - } - } - - return removedDomains -} - -// mergeServerDomains merges new server domains with existing ones. -// When a domain type is provided in the new domains, it completely replaces that type. -func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains { - merged := existing - - if incoming.Signal != "" { - merged.Signal = incoming.Signal - } - if len(incoming.Relay) > 0 { - merged.Relay = incoming.Relay - } - if incoming.Flow != "" { - merged.Flow = incoming.Flow - } - if len(incoming.Stuns) > 0 { - merged.Stuns = incoming.Stuns - } - if len(incoming.Turns) > 0 { - merged.Turns = incoming.Turns - } - - return merged -} - -// isDomainInList checks if domain exists in the list -func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool { - for _, d := range list { - if domain.SafeString() == d.SafeString() { - return true - } - } - return false -} - -// isManagementDomain checks if domain is the protected management domain -func (m *Resolver) isManagementDomain(domain domain.Domain) bool { - m.mutex.RLock() - defer m.mutex.RUnlock() - - return m.mgmtDomain != nil && domain == *m.mgmtDomain -} - -// addNewDomains resolves and caches all domains from the update -func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) { - for _, newDomain := range newDomains { - if err := m.AddDomain(ctx, newDomain); err != nil { - log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) - } else { - log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) - } - } -} - -func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List { - var domains domain.List - - if serverDomains.Signal != "" { - domains = append(domains, serverDomains.Signal) - } - - for _, relay := range serverDomains.Relay { - if relay != "" { - domains = append(domains, relay) - } - } - - if serverDomains.Flow != "" { - domains = append(domains, serverDomains.Flow) - } - - for _, stun := range serverDomains.Stuns { - if stun != "" { - domains = append(domains, stun) - } - } - - for _, turn := range serverDomains.Turns { - if turn != "" { - domains = append(domains, turn) - } - } - - return domains -} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go deleted file mode 100644 index 99d289871..000000000 --- a/client/internal/dns/mgmt/mgmt_test.go +++ /dev/null @@ -1,416 +0,0 @@ -package mgmt - -import ( - "context" - "fmt" - "net/url" - "strings" - "testing" - - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - - dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" - "github.com/netbirdio/netbird/client/internal/dns/test" - "github.com/netbirdio/netbird/shared/management/domain" -) - -func TestResolver_NewResolver(t *testing.T) { - resolver := NewResolver() - - assert.NotNil(t, resolver) - assert.NotNil(t, resolver.records) - assert.False(t, resolver.MatchSubdomains()) -} - -func TestResolver_ExtractDomainFromURL(t *testing.T) { - tests := []struct { - name string - urlStr string - expectedDom string - expectError bool - }{ - { - name: "HTTPS URL with port", - urlStr: "https://api.netbird.io:443", - expectedDom: "api.netbird.io", - expectError: false, - }, - { - name: "HTTP URL without port", - urlStr: "http://signal.example.com", - expectedDom: "signal.example.com", - expectError: false, - }, - { - name: "URL with path", - urlStr: "https://relay.netbird.io/status", - expectedDom: "relay.netbird.io", - expectError: false, - }, - { - name: "Invalid URL", - urlStr: "not-a-valid-url", - expectedDom: "not-a-valid-url", - expectError: false, - }, - { - name: "Empty URL", - urlStr: "", - expectedDom: "", - expectError: true, - }, - { - name: "STUN URL", - urlStr: "stun:stun.example.com:3478", - expectedDom: "stun.example.com", - expectError: false, - }, - { - name: "TURN URL", - urlStr: "turn:turn.example.com:3478", - expectedDom: "turn.example.com", - expectError: false, - }, - { - name: "REL URL", - urlStr: "rel://relay.example.com:443", - expectedDom: "relay.example.com", - expectError: false, - }, - { - name: "RELS URL", - urlStr: "rels://relay.example.com:443", - expectedDom: "relay.example.com", - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var parsedURL *url.URL - var err error - - if tt.urlStr != "" { - parsedURL, err = url.Parse(tt.urlStr) - if err != nil && !tt.expectError { - t.Fatalf("Failed to parse URL: %v", err) - } - } - - domain, err := extractDomainFromURL(parsedURL) - - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expectedDom, domain.SafeString()) - } - }) - } -} - -func TestResolver_PopulateFromConfig(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - resolver := NewResolver() - - // Test with IP address - should return error since IP addresses are rejected - mgmtURL, _ := url.Parse("https://127.0.0.1") - - err := resolver.PopulateFromConfig(ctx, mgmtURL) - assert.Error(t, err) - assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed) - - // No domains should be cached when using IP addresses - domains := resolver.GetCachedDomains() - assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") -} - -func TestResolver_ServeDNS(t *testing.T) { - resolver := NewResolver() - ctx := context.Background() - - // Add a test domain to the cache - use example.org which is reserved for testing - testDomain, err := domain.FromString("example.org") - if err != nil { - t.Fatalf("Failed to create domain: %v", err) - } - err = resolver.AddDomain(ctx, testDomain) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - - // Test A record query for cached domain - t.Run("Cached domain A record", func(t *testing.T) { - var capturedMsg *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - capturedMsg = m - return nil - }, - } - - req := new(dns.Msg) - req.SetQuestion("example.org.", dns.TypeA) - - resolver.ServeDNS(mockWriter, req) - - assert.NotNil(t, capturedMsg) - assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) - assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer") - }) - - // Test uncached domain signals to continue to next handler - t.Run("Uncached domain signals continue to next handler", func(t *testing.T) { - var capturedMsg *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - capturedMsg = m - return nil - }, - } - - req := new(dns.Msg) - req.SetQuestion("unknown.example.com.", dns.TypeA) - - resolver.ServeDNS(mockWriter, req) - - assert.NotNil(t, capturedMsg) - assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) - // Zero flag set to true signals the handler chain to continue to next handler - assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler") - assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain") - }) - - // Test that subdomains of cached domains are NOT resolved - t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) { - var capturedMsg *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - capturedMsg = m - return nil - }, - } - - // Query for a subdomain of our cached domain - req := new(dns.Msg) - req.SetQuestion("sub.example.org.", dns.TypeA) - - resolver.ServeDNS(mockWriter, req) - - assert.NotNil(t, capturedMsg) - assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) - assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains") - assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains") - }) - - // Test case-insensitive matching - t.Run("Case-insensitive domain matching", func(t *testing.T) { - var capturedMsg *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - capturedMsg = m - return nil - }, - } - - // Query with different casing - req := new(dns.Msg) - req.SetQuestion("EXAMPLE.ORG.", dns.TypeA) - - resolver.ServeDNS(mockWriter, req) - - assert.NotNil(t, capturedMsg) - assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) - assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case") - }) -} - -func TestResolver_GetCachedDomains(t *testing.T) { - resolver := NewResolver() - ctx := context.Background() - - testDomain, err := domain.FromString("example.org") - if err != nil { - t.Fatalf("Failed to create domain: %v", err) - } - err = resolver.AddDomain(ctx, testDomain) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - - cachedDomains := resolver.GetCachedDomains() - - assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain") - assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original") - assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot") -} - -func TestResolver_ManagementDomainProtection(t *testing.T) { - resolver := NewResolver() - ctx := context.Background() - - mgmtURL, _ := url.Parse("https://example.org") - err := resolver.PopulateFromConfig(ctx, mgmtURL) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - - initialDomains := resolver.GetCachedDomains() - if len(initialDomains) == 0 { - t.Skip("Management domain failed to resolve, skipping test") - } - assert.Equal(t, 1, len(initialDomains), "Should have management domain cached") - assert.Equal(t, "example.org", initialDomains[0].SafeString()) - - serverDomains := dnsconfig.ServerDomains{ - Signal: "google.com", - Relay: []domain.Domain{"cloudflare.com"}, - } - - _, err = resolver.UpdateFromServerDomains(ctx, serverDomains) - if err != nil { - t.Logf("Server domains update failed: %v", err) - } - - finalDomains := resolver.GetCachedDomains() - - managementStillCached := false - for _, d := range finalDomains { - if d.SafeString() == "example.org" { - managementStillCached = true - break - } - } - assert.True(t, managementStillCached, "Management domain should never be removed") -} - -// extractDomainFromURL extracts a domain from a URL - test helper function -func extractDomainFromURL(u *url.URL) (domain.Domain, error) { - if u == nil { - return "", fmt.Errorf("URL is nil") - } - return dnsconfig.ExtractValidDomain(u.String()) -} - -func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) { - resolver := NewResolver() - ctx := context.Background() - - // Set up initial domains using resolvable domains - initialDomains := dnsconfig.ServerDomains{ - Signal: "example.org", - Stuns: []domain.Domain{"google.com"}, - Turns: []domain.Domain{"cloudflare.com"}, - } - - // Add initial domains - _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - - // Verify domains were added - cachedDomains := resolver.GetCachedDomains() - assert.Len(t, cachedDomains, 3) - - // Update with empty ServerDomains (simulating partial network map update) - emptyDomains := dnsconfig.ServerDomains{} - removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains) - assert.NoError(t, err) - - // Verify no domains were removed - assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty") - - // Verify all original domains are still cached - finalDomains := resolver.GetCachedDomains() - assert.Len(t, finalDomains, 3, "All original domains should still be cached") -} - -func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) { - resolver := NewResolver() - ctx := context.Background() - - // Set up initial complete domains using resolvable domains - initialDomains := dnsconfig.ServerDomains{ - Signal: "example.org", - Stuns: []domain.Domain{"google.com"}, - Turns: []domain.Domain{"cloudflare.com"}, - } - - // Add initial domains - _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - assert.Len(t, resolver.GetCachedDomains(), 3) - - // Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn) - partialDomains := dnsconfig.ServerDomains{ - Signal: "github.com", - } - removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - - // Should remove only the old signal domain - assert.Len(t, removedDomains, 1, "Should remove only the old signal domain") - assert.Equal(t, "example.org", removedDomains[0].SafeString()) - - finalDomains := resolver.GetCachedDomains() - assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains") - - domainStrings := make([]string, len(finalDomains)) - for i, d := range finalDomains { - domainStrings[i] = d.SafeString() - } - assert.Contains(t, domainStrings, "github.com") - assert.Contains(t, domainStrings, "google.com") - assert.Contains(t, domainStrings, "cloudflare.com") - assert.NotContains(t, domainStrings, "example.org") -} - -func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { - resolver := NewResolver() - ctx := context.Background() - - // Set up initial complete domains using resolvable domains - initialDomains := dnsconfig.ServerDomains{ - Signal: "example.org", - Stuns: []domain.Domain{"google.com"}, - Turns: []domain.Domain{"cloudflare.com"}, - } - - // Add initial domains - _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - assert.Len(t, resolver.GetCachedDomains(), 3) - - // Update with partial ServerDomains (only flow domain - new type, should preserve all existing) - partialDomains := dnsconfig.ServerDomains{ - Flow: "github.com", - } - removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains) - if err != nil { - t.Skipf("Skipping test due to DNS resolution failure: %v", err) - } - - assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") - - finalDomains := resolver.GetCachedDomains() - assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") - - domainStrings := make([]string, len(finalDomains)) - for i, d := range finalDomains { - domainStrings[i] = d.SafeString() - } - assert.Contains(t, domainStrings, "example.org") - assert.Contains(t, domainStrings, "google.com") - assert.Contains(t, domainStrings, "cloudflare.com") - assert.Contains(t, domainStrings, "github.com") -} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 0f89b9016..d160fa99a 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,23 +3,20 @@ package dns import ( "fmt" "net/netip" - "net/url" "github.com/miekg/dns" - dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/shared/management/domain" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func(domain.List, dns.Handler, int) - DeregisterHandlerFunc func(domain.List, int) - UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func(domain.List, dns.Handler, int) + DeregisterHandlerFunc func(domain.List, int) } func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -73,14 +70,3 @@ func (m *MockServer) SearchDomains() []string { // ProbeAvailability mocks implementation of ProbeAvailability from the Server interface func (m *MockServer) ProbeAvailability() { } - -func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { - if m.UpdateServerConfigFunc != nil { - return m.UpdateServerConfigFunc(domains) - } - return nil -} - -func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { - return nil -} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8cb886203..cbcf6a256 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/netip" - "net/url" "runtime" "strings" "sync" @@ -16,9 +15,7 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" - dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" - "github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -48,8 +45,6 @@ type Server interface { OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string ProbeAvailability() - UpdateServerConfig(domains dnsconfig.ServerDomains) error - PopulateManagementDomain(mgmtURL *url.URL) error } type nsGroupsByDomain struct { @@ -82,8 +77,6 @@ type DefaultServer struct { handlerChain *HandlerChain extraDomains map[domain.Domain]int - mgmtCacheResolver *mgmt.Resolver - // permanent related properties permanent bool hostsDNSHolder *hostsDNSHolder @@ -111,20 +104,18 @@ type handlerWrapper struct { type registeredHandlerMap map[types.HandlerID]handlerWrapper -// DefaultServerConfig holds configuration parameters for NewDefaultServer -type DefaultServerConfig struct { - WgInterface WGIface - CustomAddress string - StatusRecorder *peer.Status - StateManager *statemanager.Manager - DisableSys bool -} - // NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) { +func NewDefaultServer( + ctx context.Context, + wgInterface WGIface, + customAddress string, + statusRecorder *peer.Status, + stateManager *statemanager.Manager, + disableSys bool, +) (*DefaultServer, error) { var addrPort *netip.AddrPort - if config.CustomAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) + if customAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(customAddress) if err != nil { return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) } @@ -132,14 +123,13 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default } var dnsService service - if config.WgInterface.IsUserspaceBind() { - dnsService = NewServiceViaMemory(config.WgInterface) + if wgInterface.IsUserspaceBind() { + dnsService = NewServiceViaMemory(wgInterface) } else { - dnsService = newServiceViaListener(config.WgInterface, addrPort) + dnsService = newServiceViaListener(wgInterface, addrPort) } - server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) - return server, nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems @@ -188,24 +178,20 @@ func newDefaultServer( ) *DefaultServer { handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) - - mgmtCacheResolver := mgmt.NewResolver() - defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: local.NewResolver(), - wgInterface: wgInterface, - statusRecorder: statusRecorder, - stateManager: stateManager, - hostsDNSHolder: newHostsDNSHolder(), - hostManager: &noopHostConfigurator{}, - mgmtCacheResolver: mgmtCacheResolver, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), + wgInterface: wgInterface, + statusRecorder: statusRecorder, + stateManager: stateManager, + hostsDNSHolder: newHostsDNSHolder(), + hostManager: &noopHostConfigurator{}, } // register with root zone, handler chain takes care of the routing @@ -231,7 +217,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { - log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains) + log.Debugf("registering handler %s with priority %d", handler, priority) for _, domain := range domains { if domain == "" { @@ -260,7 +246,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { } func (s *DefaultServer) deregisterHandler(domains []string, priority int) { - log.Debugf("deregistering handler with priority %d for %v", priority, domains) + log.Debugf("deregistering handler %v with priority %d", domains, priority) for _, domain := range domains { if domain == "" { @@ -446,29 +432,6 @@ func (s *DefaultServer) ProbeAvailability() { wg.Wait() } -func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { - s.mux.Lock() - defer s.mux.Unlock() - - if s.mgmtCacheResolver != nil { - removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) - if err != nil { - return fmt.Errorf("update management cache resolver: %w", err) - } - - if len(removedDomains) > 0 { - s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) - } - - newDomains := s.mgmtCacheResolver.GetCachedDomains() - if len(newDomains) > 0 { - s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) - } - } - - return nil -} - func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver if update.ServiceEnable { @@ -998,11 +961,3 @@ func toZone(d domain.Domain) domain.Domain { ), ) } - -// PopulateManagementDomain populates the DNS cache with management domain -func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { - if s.mgmtCacheResolver != nil { - return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) - } - return nil -} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 11575d500..068f001d8 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -363,13 +363,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ - WgInterface: wgIface, - CustomAddress: "", - StatusRecorder: peer.NewRecorder("mgm"), - StateManager: nil, - DisableSys: false, - }) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) if err != nil { t.Fatal(err) } @@ -479,13 +473,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ - WgInterface: wgIface, - CustomAddress: "", - StatusRecorder: peer.NewRecorder("mgm"), - StateManager: nil, - DisableSys: false, - }) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -587,13 +575,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ - WgInterface: &mocWGIface{}, - CustomAddress: testCase.addrPort, - StatusRecorder: peer.NewRecorder("mgm"), - StateManager: nil, - DisableSys: false, - }) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 6ef0ab526..89d637686 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) type ServiceViaMemory struct { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c19e0acb5..f5d0e775f 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -26,18 +26,10 @@ import ( "github.com/netbirdio/netbird/client/proto" ) -var currentMTU uint16 = iface.DefaultMTU - -func SetCurrentMTU(mtu uint16) { - currentMTU = mtu -} - const ( - UpstreamTimeout = 4 * time.Second - // ClientTimeout is the timeout for the dns.Client. - // Set longer than UpstreamTimeout to ensure context timeout takes precedence - ClientTimeout = 5 * time.Second + UpstreamTimeout = 15 * time.Second + failsTillDeact = int32(5) reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second ) @@ -60,7 +52,9 @@ type upstreamResolverBase struct { upstreamServers []netip.AddrPort domain string disabled bool + failsCount atomic.Int32 successCount atomic.Int32 + failsTillDeact int32 mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration @@ -79,13 +73,14 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain: domain, upstreamTimeout: UpstreamTimeout, reactivatePeriod: reactivatePeriod, + failsTillDeact: failsTillDeact, statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("Upstream %s", u.upstreamServers) + return fmt.Sprintf("upstream %s", u.upstreamServers) } // ID returns the unique handler ID @@ -115,102 +110,58 @@ func (u *upstreamResolverBase) Stop() { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { requestID := GenerateRequestID() logger := log.WithField("request_id", requestID) + var err error + defer func() { + u.checkUpstreamFails(err) + }() logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) - - u.prepareRequest(r) - - if u.ctx.Err() != nil { - logger.Tracef("%s has been stopped", u) - return - } - - if u.tryUpstreamServers(w, r, logger) { - return - } - - u.writeErrorResponse(w, r, logger) -} - -func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } -} -func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { - timeout := u.upstreamTimeout - if len(u.upstreamServers) > 1 { - maxTotal := 5 * time.Second - minPerUpstream := 2 * time.Second - scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) - if scaledTimeout > minPerUpstream { - timeout = scaledTimeout - } else { - timeout = minPerUpstream - } + select { + case <-u.ctx.Done(): + logger.Tracef("%s has been stopped", u) + return + default: } for _, upstream := range u.upstreamServers { - if u.queryUpstream(w, r, upstream, timeout, logger) { - return true + var rm *dns.Msg + var t time.Duration + + func() { + ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) + defer cancel() + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) + }() + + if err != nil { + if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { + logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) + continue + } + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) + continue } - } - return false -} -func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool { - var rm *dns.Msg - var t time.Duration - var err error + if rm == nil || !rm.Response { + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + continue + } - var startTime time.Time - func() { - ctx, cancel := context.WithTimeout(u.ctx, timeout) - defer cancel() - startTime = time.Now() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) - }() + u.successCount.Add(1) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) - if err != nil { - u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) - return false - } - - if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - return false - } - - return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) -} - -func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { - if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) + if err = w.WriteMsg(rm); err != nil { + logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) + } + // count the fails only if they happen sequentially + u.failsCount.Store(0) return } - - elapsed := time.Since(startTime) - timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) - if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { - timeoutMsg += " " + peerInfo - } - timeoutMsg += fmt.Sprintf(" - error: %v", err) - logger.Warnf(timeoutMsg) -} - -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { - u.successCount.Add(1) - logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain) - - if err := w.WriteMsg(rm); err != nil { - logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) - } - return true -} - -func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { + u.failsCount.Add(1) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) @@ -220,6 +171,41 @@ func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.M } } +// checkUpstreamFails counts fails and disables or enables upstream resolving +// +// If fails count is greater that failsTillDeact, upstream resolving +// will be disabled for reactivatePeriod, after that time period fails counter +// will be reset and upstream will be reactivated. +func (u *upstreamResolverBase) checkUpstreamFails(err error) { + u.mutex.Lock() + defer u.mutex.Unlock() + + if u.failsCount.Load() < u.failsTillDeact || u.disabled { + return + } + + select { + case <-u.ctx.Done(): + return + default: + } + + u.disable(err) + + if u.statusRecorder == nil { + return + } + + u.statusRecorder.PublishEvent( + proto.SystemEvent_WARNING, + proto.SystemEvent_DNS, + "All upstream servers failed (fail count exceeded)", + "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", + map[string]string{"upstreams": u.upstreamServersString()}, + // TODO add domain meta + ) +} + // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability() { @@ -232,8 +218,8 @@ func (u *upstreamResolverBase) ProbeAvailability() { default: } - // avoid probe if upstreams could resolve at least one query - if u.successCount.Load() > 0 { + // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact + if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { return } @@ -320,6 +306,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) + u.failsCount.Store(0) u.successCount.Add(1) u.reactivate() u.disabled = false @@ -371,8 +358,8 @@ func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout tim // If the passed context is nil, this will use Exchange instead of ExchangeContext. func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) { // MTU - ip + udp headers - // Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling. - client.UDPSize = uint16(currentMTU - (60 + 8)) + // Note: this could be sent out on an interface that is not ours, but our MTU should always be lower. + client.UDPSize = iface.DefaultMTU - (60 + 8) var ( rm *dns.Msg @@ -423,80 +410,3 @@ func GenerateRequestID() string { } return hex.EncodeToString(bytes) } - -// FormatPeerStatus formats peer connection status information for debugging DNS timeouts -func FormatPeerStatus(peerState *peer.State) string { - isConnected := peerState.ConnStatus == peer.StatusConnected - hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() && - time.Since(peerState.LastWireguardHandshake) < 3*time.Minute - - statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP) - - switch { - case !isConnected: - statusInfo += " DISCONNECTED" - case !hasRecentHandshake: - statusInfo += " NO_RECENT_HANDSHAKE" - default: - statusInfo += " connected" - } - - if !peerState.LastWireguardHandshake.IsZero() { - timeSinceHandshake := time.Since(peerState.LastWireguardHandshake) - statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second)) - } else { - statusInfo += " no_handshake" - } - - if peerState.Relayed { - statusInfo += " via_relay" - } - - if peerState.Latency > 0 { - statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency) - } - - return statusInfo -} - -// findPeerForIP finds which peer handles the given IP address -func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { - if statusRecorder == nil { - return nil - } - - fullStatus := statusRecorder.GetFullStatus() - var bestMatch *peer.State - var bestPrefixLen int - - for _, peerState := range fullStatus.Peers { - routes := peerState.GetRoutes() - for route := range routes { - prefix, err := netip.ParsePrefix(route) - if err != nil { - continue - } - - if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen { - peerStateCopy := peerState - bestMatch = &peerStateCopy - bestPrefixLen = prefix.Bits() - } - } - } - - return bestMatch -} - -func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { - if u.statusRecorder == nil { - return "" - } - - peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) - if peerInfo == nil { - return "" - } - - return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) -} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index def281f28..ddbf84ae4 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) type upstreamResolver struct { @@ -50,9 +50,7 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns } func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - upstreamExchangeClient := &dns.Client{ - Timeout: ClientTimeout, - } + upstreamExchangeClient := &dns.Client{} return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } @@ -74,11 +72,10 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri } upstreamExchangeClient := &dns.Client{ - Dialer: dialer, - Timeout: timeout, + Dialer: dialer, } - return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) + return upstreamExchangeClient.Exchange(r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 434e5880b..317588a27 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -34,10 +34,7 @@ func newUpstreamResolver( } func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - client := &dns.Client{ - Timeout: ClientTimeout, - } - return ExchangeWithFallback(ctx, client, r, upstream) + return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) } func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index eadcdd117..96b8bbb0f 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -47,9 +47,7 @@ func newUpstreamResolver( } func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - client := &dns.Client{ - Timeout: ClientTimeout, - } + client := &dns.Client{} upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) @@ -112,8 +110,7 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura }, } client := &dns.Client{ - Dialer: dialer, - Timeout: dialTimeout, + Dialer: dialer, } return client, nil } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index e1573e75e..51d870e2a 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -124,26 +124,29 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) } func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { - mockClient := &mockUpstreamResolver{ - err: dns.ErrTime, - r: new(dns.Msg), - rtt: time.Millisecond, - } - resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: mockClient, + ctx: context.TODO(), + upstreamClient: &mockUpstreamResolver{ + err: nil, + r: new(dns.Msg), + rtt: time.Millisecond, + }, upstreamTimeout: UpstreamTimeout, - reactivatePeriod: time.Microsecond * 100, + reactivatePeriod: reactivatePeriod, + failsTillDeact: failsTillDeact, } addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} + resolver.failsTillDeact = 0 + resolver.reactivatePeriod = time.Microsecond * 100 + + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { return nil }, + } failed := false resolver.deactivate = func(error) { failed = true - // After deactivation, make the mock client work again - mockClient.err = nil } reactivated := false @@ -151,7 +154,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ProbeAvailability() + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) if !failed { t.Errorf("expected that resolving was deactivated") @@ -170,6 +173,11 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { return } + if resolver.failsCount.Load() != 0 { + t.Errorf("fails count after reactivation should be 0") + return + } + if resolver.disabled { t.Errorf("should be enabled") } diff --git a/client/internal/engine.go b/client/internal/engine.go index fefe2e96c..4e847758d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,7 +7,6 @@ import ( "math/rand" "net" "net/netip" - "net/url" "os" "reflect" "runtime" @@ -18,8 +17,8 @@ import ( "time" "github.com/hashicorp/go-multierror" - "github.com/pion/ice/v4" - "github.com/pion/stun/v3" + "github.com/pion/ice/v3" + "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -29,13 +28,12 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" - dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" @@ -136,7 +134,6 @@ type EngineConfig struct { ProfileConfig *profilemanager.Config LogFile string - MTU uint16 } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. @@ -174,7 +171,7 @@ type Engine struct { wgInterface WGIface - udpMux *udpmux.UniversalUDPMuxDefault + udpMux *bind.UniversalUDPMuxDefault // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -210,10 +207,6 @@ type Engine struct { jobExecutor *jobexec.Executor jobExecutorWG sync.WaitGroup - - // WireGuard interface monitor - wgIfaceMonitor *WGIfaceMonitor - wgIfaceMonitorWg sync.WaitGroup } // Peer is an instance of the Connection Peer @@ -350,23 +343,16 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } - // Stop WireGuard interface monitor and wait for it to exit - e.wgIfaceMonitorWg.Wait() - return nil } // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. // However, they will be established once an event with a list of peers to connect to will be received from Management Service -func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { +func (e *Engine) Start() error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - if err := iface.ValidateMTU(e.config.MTU); err != nil { - return fmt.Errorf("invalid MTU configuration: %w", err) - } - if e.cancel != nil { e.cancel() } @@ -415,11 +401,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) } e.dnsServer = dnsServer - // Populate DNS cache with NetbirdConfig and management URL for early resolution - if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil { - log.Warnf("failed to populate DNS cache: %v", err) - } - e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ Context: e.ctx, PublicKey: e.config.WgPrivateKey.PublicKey().String(), @@ -458,8 +439,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } - - // if inbound conns are blocked there is no need to create the ACL manager if e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) @@ -475,7 +454,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, + UDPMux: e.udpMux.UDPMuxDefault, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), } @@ -492,22 +471,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // starting network monitor at the very last to avoid disruptions e.startNetworkMonitor() - - // monitor WireGuard interface lifecycle and restart engine on changes - e.wgIfaceMonitor = NewWGIfaceMonitor() - e.wgIfaceMonitorWg.Add(1) - - go func() { - defer e.wgIfaceMonitorWg.Done() - - if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { - log.Infof("WireGuard interface monitor: %s, restarting engine", err) - e.restartEngine() - } else if err != nil { - log.Warnf("WireGuard interface monitor: %s", err) - } - }() - return nil } @@ -699,30 +662,6 @@ func (e *Engine) removePeer(peerKey string) error { return nil } -// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response -func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { - if e.dnsServer == nil { - return nil - } - - // Populate management URL if provided - if mgmtURL != nil { - if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil { - log.Warnf("failed to populate DNS cache with management URL: %v", err) - } - } - - // Populate NetbirdConfig domains if provided - if netbirdConfig != nil { - serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig) - if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { - return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err) - } - } - - return nil -} - func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -754,10 +693,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("handle the flow configuration: %w", err) } - if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil { - log.Warnf("Failed to update DNS server config: %v", err) - } - // todo update signal } @@ -1066,6 +1001,7 @@ func (e *Engine) receiveManagementEvents() { e.config.LazyConnectionEnabled, ) + // err = e.mgmClient.Sync(info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) if err != nil { // happens if management is unavailable for a long time. @@ -1076,7 +1012,7 @@ func (e *Engine) receiveManagementEvents() { } log.Debugf("stopped receiving updates from Management Service") }() - log.Infof("connecting to Management Service updates stream") + log.Debugf("connecting to Management Service updates stream") } func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { @@ -1268,16 +1204,15 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { } convertedRoute := &route.Route{ - ID: route.ID(protoRoute.ID), - Network: prefix.Masked(), - Domains: domain.FromPunycodeList(protoRoute.Domains), - NetID: route.NetID(protoRoute.NetID), - NetworkType: route.NetworkType(protoRoute.NetworkType), - Peer: protoRoute.Peer, - Metric: int(protoRoute.Metric), - Masquerade: protoRoute.Masquerade, - KeepRoute: protoRoute.KeepRoute, - SkipAutoApply: protoRoute.SkipAutoApply, + ID: route.ID(protoRoute.ID), + Network: prefix.Masked(), + Domains: domain.FromPunycodeList(protoRoute.Domains), + NetID: route.NetID(protoRoute.NetID), + NetworkType: route.NetworkType(protoRoute.NetworkType), + Peer: protoRoute.Peer, + Metric: int(protoRoute.Metric), + Masquerade: protoRoute.Masquerade, + KeepRoute: protoRoute.KeepRoute, } routes = append(routes, convertedRoute) } @@ -1443,7 +1378,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, + UDPMux: e.udpMux.UDPMuxDefault, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), }, @@ -1649,7 +1584,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { Address: e.config.WgAddr, WGPort: e.config.WgPort, WGPrivKey: e.config.WgPrivateKey.String(), - MTU: e.config.MTU, + MTU: iface.DefaultMTU, TransportNet: transportNet, FilterFn: e.addrViaRoutes, DisableDNS: e.config.DisableDNS, @@ -1708,14 +1643,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil default: - - dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ - WgInterface: e.wgInterface, - CustomAddress: e.config.CustomDNSAddress, - StatusRecorder: e.statusRecorder, - StateManager: e.stateManager, - DisableSys: e.config.DisableDNS, - }) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) if err != nil { return nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index aeeb68e79..1a179c6ce 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -19,22 +19,22 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - wgdevice "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" @@ -46,12 +46,9 @@ import ( "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -89,7 +86,7 @@ type MockWGIface struct { NameFunc func() string AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface - UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) + UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error @@ -139,7 +136,7 @@ func (m *MockWGIface) ToInterface() *net.Interface { return m.ToInterfaceFunc() } -func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { +func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { return m.UpFunc() } @@ -222,25 +219,14 @@ func TestEngine_SSH(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine( - ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, - &EngineConfig{ - WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - ServerSSHAllowed: true, - MTU: iface.DefaultMTU, - }, - MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil, - ) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + WgIfaceName: "utun101", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + ServerSSHAllowed: true, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -271,7 +257,7 @@ func TestEngine_SSH(t *testing.T) { }, }, nil } - err = engine.Start(nil, nil) + err = engine.Start() if err != nil { t.Fatal(err) } @@ -369,23 +355,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) - engine := NewEngine( - ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, - &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - MTU: iface.DefaultMTU, - }, - MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -420,7 +396,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) + engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.ctx = ctx engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) @@ -597,14 +573,13 @@ func TestEngine_Sync(t *testing.T) { } return nil } - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ WgIfaceName: "utun103", WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -618,7 +593,7 @@ func TestEngine_Sync(t *testing.T) { } }() - err = engine.Start(nil, nil) + err = engine.Start() if err != nil { t.Fatal(err) return @@ -762,14 +737,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx newNet, err := stdnet.NewNet() if err != nil { @@ -964,14 +938,13 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - MTU: iface.DefaultMTU, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx newNet, err := stdnet.NewNet() @@ -1075,7 +1048,7 @@ func TestEngine_MultiplePeers(t *testing.T) { defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - err = engine.Start(nil, nil) + err = engine.Start() if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) wg.Done() @@ -1192,7 +1165,6 @@ func Test_ParseNATExternalIPMappings(t *testing.T) { config: &EngineConfig{ IFaceBlackList: testCase.inputBlacklistInterface, NATExternalIPs: testCase.inputMapList, - MTU: iface.DefaultMTU, }, } parsedList := engine.parseNATExternalIPMappings() @@ -1493,12 +1465,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin WgAddr: resp.PeerConfig.Address, WgPrivateKey: key, WgPort: wgPort, - MTU: iface.DefaultMTU, } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil - e.ctx = ctx return e, err } @@ -1563,11 +1533,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - - permissionsManager := permissions.NewManager(store) - peersManager := peers.NewManager(store, permissionsManager) - - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1584,6 +1550,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri Return(&types.ExtraSettings{}, nil). AnyTimes() + permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 690fdb7cc..bf96153ea 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -9,9 +9,9 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -24,7 +24,7 @@ type wgIfaceBase interface { Name() string Address() wgaddr.Address ToInterface() *net.Interface - Up() (*udpmux.UniversalUDPMuxDefault, error) + Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/login.go b/client/internal/login.go index 257e3c3ac..d5412a110 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, return false, err } - _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) + _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) if isLoginNeeded(err) { return true, nil } @@ -69,18 +69,14 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string, return err } - serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) + serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) if serverKey != nil && isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) - if err != nil { - return err - } - } else if err != nil { return err } - return nil + return err } func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { @@ -105,11 +101,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm return mgmClient, err } -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { serverKey, err := mgmClient.GetServerPublicKey() if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) - return nil, nil, err + return nil, err } sysInfo := system.GetInfo(ctx) @@ -125,8 +121,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockInbound, config.LazyConnectionEnabled, ) - loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, loginResp, err + _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) + return serverKey, err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index a4ffa3a25..dbb4747a5 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -14,7 +14,7 @@ import ( "github.com/ti-mo/netfilter" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) const defaultChannelSize = 100 diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8db9e58f4..a6cf3cd25 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -6,11 +6,12 @@ import ( "math/rand" "net" "net/netip" + "os" "runtime" "sync" "time" - "github.com/pion/ice/v4" + "github.com/pion/ice/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -28,6 +29,10 @@ import ( semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) +const ( + defaultWgKeepAlive = 25 * time.Second +) + type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler @@ -113,8 +118,6 @@ type Conn struct { // debug purpose dumpState *stateDump - - endpointUpdater *EndpointUpdater } // NewConn creates a new not opened Conn to the remote peer. @@ -127,18 +130,17 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), - endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), } return conn, nil @@ -172,7 +174,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) - if !isForceRelayed() { + if os.Getenv("NB_FORCE_RELAY") != "true" { conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } @@ -248,7 +250,7 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgProxyICE = nil } - if err := conn.endpointUpdater.RemoveWgPeer(); err != nil { + if err := conn.removeWgPeer(); err != nil { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } @@ -374,19 +376,12 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn wgProxy.Work() } - conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) - presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) - if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { + if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() - if conn.wgProxyRelay != nil { - conn.Log.Debugf("redirect packets from relayed conn to WireGuard") - conn.wgProxyRelay.RedirectAs(ep) - } - conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -415,8 +410,7 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) - if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil { + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { conn.Log.Errorf("failed to switch to relay conn: %v", err) } @@ -425,7 +419,6 @@ func (conn *Conn) onICEStateDisconnected() { defer conn.wgWatcherWg.Done() conn.workerRelay.EnableWgWatcher(conn.ctx) }() - conn.wgProxyRelay.Work() conn.currentConnPriority = conntype.Relay } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) @@ -485,8 +478,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - presharedKey := conn.presharedKey(rci.rosenpassPubKey) - if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) } @@ -554,6 +546,17 @@ func (conn *Conn) onGuardEvent() { } } +func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { + presharedKey := conn.presharedKey(remoteRPKey) + return conn.config.WgConfig.WgInterface.UpdatePeer( + conn.config.WgConfig.RemoteKey, + conn.config.WgConfig.AllowedIps, + defaultWgKeepAlive, + addr, + presharedKey, + ) +} + func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -696,6 +699,10 @@ func (conn *Conn) isICEActive() bool { return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected } +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.Log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go deleted file mode 100644 index 39cb95591..000000000 --- a/client/internal/peer/endpoint.go +++ /dev/null @@ -1,105 +0,0 @@ -package peer - -import ( - "context" - "net" - "sync" - "time" - - "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -const ( - defaultWgKeepAlive = 25 * time.Second - fallbackDelay = 5 * time.Second -) - -type EndpointUpdater struct { - log *logrus.Entry - wgConfig WgConfig - initiator bool - - // mu protects updateWireGuardPeer and cancelFunc - mu sync.Mutex - cancelFunc func() - updateWg sync.WaitGroup -} - -func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater { - return &EndpointUpdater{ - log: log, - wgConfig: wgConfig, - initiator: initiator, - } -} - -// ConfigureWGEndpoint sets up the WireGuard endpoint configuration. -// The initiator immediately configures the endpoint, while the non-initiator -// waits for a fallback period before configuring to avoid handshake congestion. -func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error { - e.mu.Lock() - defer e.mu.Unlock() - - if e.initiator { - e.log.Debugf("configure up WireGuard as initiatr") - return e.updateWireGuardPeer(addr, presharedKey) - } - - // prevent to run new update while cancel the previous update - e.waitForCloseTheDelayedUpdate() - - var ctx context.Context - ctx, e.cancelFunc = context.WithCancel(context.Background()) - e.updateWg.Add(1) - go e.scheduleDelayedUpdate(ctx, addr, presharedKey) - - e.log.Debugf("configure up WireGuard and wait for handshake") - return e.updateWireGuardPeer(nil, presharedKey) -} - -func (e *EndpointUpdater) RemoveWgPeer() error { - e.mu.Lock() - defer e.mu.Unlock() - - e.waitForCloseTheDelayedUpdate() - return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) -} - -func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { - if e.cancelFunc == nil { - return - } - - e.cancelFunc() - e.cancelFunc = nil - e.updateWg.Wait() -} - -// scheduleDelayedUpdate waits for the fallback period before updating the endpoint -func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) { - defer e.updateWg.Done() - t := time.NewTimer(fallbackDelay) - defer t.Stop() - - select { - case <-ctx.Done(): - return - case <-t.C: - e.mu.Lock() - if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { - e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) - } - e.mu.Unlock() - } -} - -func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error { - return e.wgConfig.WgInterface.UpdatePeer( - e.wgConfig.RemoteKey, - e.wgConfig.AllowedIps, - defaultWgKeepAlive, - endpoint, - presharedKey, - ) -} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go deleted file mode 100644 index 32a458d00..000000000 --- a/client/internal/peer/env.go +++ /dev/null @@ -1,14 +0,0 @@ -package peer - -import ( - "os" - "strings" -) - -const ( - EnvKeyNBForceRelay = "NB_FORCE_RELAY" -) - -func isForceRelayed() bool { - return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") -} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 70850e6eb..b9c9aa134 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/pion/ice/v4" + "github.com/pion/ice/v3" log "github.com/sirupsen/logrus" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 42eaea683..3cbf74cfd 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -43,6 +43,13 @@ type OfferAnswer struct { SessionID *ICESessionID } +func (oa *OfferAnswer) SessionIDString() string { + if oa.SessionID == nil { + return "unknown" + } + return oa.SessionID.String() +} + type Handshaker struct { mu sync.Mutex log *log.Entry @@ -50,7 +57,7 @@ type Handshaker struct { signaler *Signaler ice *WorkerICE relay *WorkerRelay - onNewOfferListeners []*OfferListener + onNewOfferListeners []func(*OfferAnswer) // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer @@ -71,8 +78,7 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - l := NewOfferListener(offer) - h.onNewOfferListeners = append(h.onNewOfferListeners, l) + h.onNewOfferListeners = append(h.onNewOfferListeners, offer) } func (h *Handshaker) Listen(ctx context.Context) { @@ -85,13 +91,13 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) + listener(&remoteOfferAnswer) } h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) + listener(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go deleted file mode 100644 index e2d3f3f38..000000000 --- a/client/internal/peer/handshaker_listener.go +++ /dev/null @@ -1,62 +0,0 @@ -package peer - -import ( - "sync" -) - -type callbackFunc func(remoteOfferAnswer *OfferAnswer) - -func (oa *OfferAnswer) SessionIDString() string { - if oa.SessionID == nil { - return "unknown" - } - return oa.SessionID.String() -} - -type OfferListener struct { - fn callbackFunc - running bool - latest *OfferAnswer - mu sync.Mutex -} - -func NewOfferListener(fn callbackFunc) *OfferListener { - return &OfferListener{ - fn: fn, - } -} - -func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { - o.mu.Lock() - defer o.mu.Unlock() - - // Store the latest offer - o.latest = remoteOfferAnswer - - // If already running, the running goroutine will pick up this latest value - if o.running { - return - } - - // Start processing - o.running = true - - // Process in a goroutine to avoid blocking the caller - go func(remoteOfferAnswer *OfferAnswer) { - for { - o.fn(remoteOfferAnswer) - - o.mu.Lock() - if o.latest == nil { - // No more work to do - o.running = false - o.mu.Unlock() - return - } - remoteOfferAnswer = o.latest - // Clear the latest to mark it as being processed - o.latest = nil - o.mu.Unlock() - } - }(remoteOfferAnswer) -} diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go deleted file mode 100644 index 8363741a5..000000000 --- a/client/internal/peer/handshaker_listener_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package peer - -import ( - "testing" - "time" -) - -func Test_newOfferListener(t *testing.T) { - dummyOfferAnswer := &OfferAnswer{} - runChan := make(chan struct{}, 10) - - longRunningFn := func(remoteOfferAnswer *OfferAnswer) { - time.Sleep(1 * time.Second) - runChan <- struct{}{} - } - - hl := NewOfferListener(longRunningFn) - - hl.Notify(dummyOfferAnswer) - hl.Notify(dummyOfferAnswer) - hl.Notify(dummyOfferAnswer) - - // Wait for exactly 2 callbacks - for i := 0; i < 2; i++ { - select { - case <-runChan: - case <-time.After(3 * time.Second): - t.Fatal("Timeout waiting for callback") - } - } - - // Verify no additional callbacks happen - select { - case <-runChan: - t.Fatal("Unexpected additional callback") - case <-time.After(100 * time.Millisecond): - t.Log("Correctly received exactly 2 callbacks") - } -} diff --git a/client/internal/peer/ice/StunTurn.go b/client/internal/peer/ice/StunTurn.go index a389f5444..63ee8c713 100644 --- a/client/internal/peer/ice/StunTurn.go +++ b/client/internal/peer/ice/StunTurn.go @@ -3,7 +3,7 @@ package ice import ( "sync/atomic" - "github.com/pion/stun/v3" + "github.com/pion/stun/v2" ) type StunTurn atomic.Value diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index e80c98884..4a0228405 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,10 +1,9 @@ package ice import ( - "sync" "time" - "github.com/pion/ice/v4" + "github.com/pion/ice/v3" "github.com/pion/logging" "github.com/pion/randutil" log "github.com/sirupsen/logrus" @@ -24,20 +23,7 @@ const ( iceRelayAcceptanceMinWaitDefault = 2 * time.Second ) -type ThreadSafeAgent struct { - *ice.Agent - once sync.Once -} - -func (a *ThreadSafeAgent) Close() error { - var err error - a.once.Do(func() { - err = a.Agent.Close() - }) - return err -} - -func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { +func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() iceFailedTimeout := iceFailedTimeout() @@ -75,12 +61,7 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} } - agent, err := ice.NewAgent(agentConfig) - if err != nil { - return nil, err - } - - return &ThreadSafeAgent{Agent: agent}, nil + return ice.NewAgent(agentConfig) } func GenerateICECredentials() (string, string, error) { diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go index dd5d67403..dd854a605 100644 --- a/client/internal/peer/ice/config.go +++ b/client/internal/peer/ice/config.go @@ -1,7 +1,7 @@ package ice import ( - "github.com/pion/ice/v4" + "github.com/pion/ice/v3" ) type Config struct { diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index b28906625..ca1d421a5 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -1,7 +1,7 @@ package peer import ( - "github.com/pion/ice/v4" + "github.com/pion/ice/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 0ed200fda..218872c15 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -30,10 +30,9 @@ type WGWatcher struct { peerKey string stateDump *stateDump - ctx context.Context - ctxCancel context.CancelFunc - ctxLock sync.Mutex - enabledTime time.Time + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -49,7 +48,6 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() - w.enabledTime = time.Now() if w.ctx != nil && w.ctx.Err() == nil { w.log.Errorf("WireGuard watcher already enabled") @@ -103,11 +101,6 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex onDisconnectedFn() return } - if lastHandshake.IsZero() { - elapsed := handshake.Sub(w.enabledTime).Seconds() - w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) - } - lastHandshake = *handshake resetTime := time.Until(handshake.Add(checkPeriod)) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index eb886a4d3..ee85254fb 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -8,11 +8,12 @@ import ( "sync" "time" - "github.com/pion/ice/v4" + "github.com/pion/ice/v3" + "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/udpmux" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -41,7 +42,7 @@ type WorkerICE struct { statusRecorder *Status hasRelayOnLocally bool - agent *icemaker.ThreadSafeAgent + agent *ice.Agent agentDialerCancel context.CancelFunc agentConnecting bool // while it is true, drop all incoming offers lastSuccess time.Time // with this avoid the too frequent ICE agent recreation @@ -54,6 +55,10 @@ type WorkerICE struct { sessionID ICESessionID muxAgent sync.Mutex + StunTurn []*stun.URI + + sentExtraSrflx bool + localUfrag string localPwd string @@ -116,7 +121,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } - w.agent = nil + // todo consider to switch to Relay connection while establishing a new ICE connection } var preferredCandidateTypes []ice.CandidateType @@ -134,6 +139,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.muxAgent.Unlock() return } + w.sentExtraSrflx = false w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true @@ -160,21 +166,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA w.log.Errorf("error while handling remote candidate") return } - - if shouldAddExtraCandidate(candidate) { - // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) - // this is useful when network has an existing port forwarding rule for the wireguard port and this peer - extraSrflx, err := extraSrflxCandidate(candidate) - if err != nil { - w.log.Errorf("failed creating extra server reflexive candidate %s", err) - return - } - - if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil { - w.log.Errorf("error while handling remote candidate") - return - } - } } func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { @@ -204,7 +195,7 @@ func (w *WorkerICE) Close() { w.agent = nil } -func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { +func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) { agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) @@ -218,12 +209,14 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates [] return nil, err } - if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) { - w.onICESelectedCandidatePair(agent, c1, c2) - }); err != nil { + if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil { return nil, err } + if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil { + return nil, fmt.Errorf("failed setting binding response callback: %w", err) + } + return agent, nil } @@ -237,7 +230,7 @@ func (w *WorkerICE) SessionID() ICESessionID { // will block until connection succeeded // but it won't release if ICE Agent went into Disconnected or Failed state, // so we have to cancel it with the provided context once agent detected a broken connection -func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) { +func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAnswer *OfferAnswer) { w.log.Debugf("gather candidates") if err := agent.GatherCandidates(); err != nil { w.log.Warnf("failed to gather candidates: %s", err) @@ -246,7 +239,7 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent } w.log.Debugf("turn agent dial") - remoteConn, err := w.turnAgentDial(ctx, agent, remoteOfferAnswer) + remoteConn, err := w.turnAgentDial(ctx, remoteOfferAnswer) if err != nil { w.log.Debugf("failed to dial the remote peer: %s", err) w.closeAgent(agent, w.agentDialerCancel) @@ -259,11 +252,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.closeAgent(agent, w.agentDialerCancel) return } - if pair == nil { - w.log.Warnf("selected candidate pair is nil, cannot proceed") - w.closeAgent(agent, w.agentDialerCancel) - return - } if !isRelayCandidate(pair.Local) { // dynamically set remote WireGuard port if other side specified a different one from the default one @@ -302,14 +290,13 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) } -func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) { +func (w *WorkerICE) closeAgent(agent *ice.Agent, cancel context.CancelFunc) { cancel() if err := agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } w.muxAgent.Lock() - // todo review does it make sense to generate new session ID all the time when w.agent==agent sessionID, err := NewICESessionID() if err != nil { w.log.Errorf("failed to create new session ID: %s", err) @@ -338,7 +325,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) return } - mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault) + mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) if !ok { w.log.Warn("invalid udp mux conversion") return @@ -365,36 +352,41 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() + + if !w.shouldSendExtraSrflxCandidate(candidate) { + return + } + + // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) + // this is useful when network has an existing port forwarding rule for the wireguard port and this peer + extraSrflx, err := extraSrflxCandidate(candidate) + if err != nil { + w.log.Errorf("failed creating extra server reflexive candidate %s", err) + return + } + w.sentExtraSrflx = true + + go func() { + err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key) + if err != nil { + w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err) + } + }() } -func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { +func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.config.Key) - - pairStat, ok := agent.GetSelectedCandidatePairStats() - if !ok { - w.log.Warnf("failed to get selected candidate pair stats") - return - } - - duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second)) - if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { - w.log.Debugf("failed to update latency for peer: %s", err) - return - } } -func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { +func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { return func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) switch state { case ice.ConnectionStateConnected: w.lastKnownState = ice.ConnectionStateConnected return - case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed: - // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to - // notify the conn.onICEStateDisconnected changes to update the current used priority - + case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() @@ -406,34 +398,32 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia } } -func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { - if isController(w.config) { - return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) - } else { - return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) +func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) { + if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil { + w.log.Debugf("failed to update latency for peer: %s", err) + return } } -func shouldAddExtraCandidate(candidate ice.Candidate) bool { - if candidate.Type() != ice.CandidateTypeServerReflexive { - return false +func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { + if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { + return true } + return false +} - if candidate.Port() == candidate.RelatedAddress().Port { - return false +func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { + isControlling := w.config.LocalKey > w.config.Key + if isControlling { + return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + } else { + return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } - - // in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates - // in newer version we generate locally the extra candidate - if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok { - return false - } - return true } func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() - ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ Network: candidate.NetworkType().String(), Address: candidate.Address(), Port: relatedAdd.Port, @@ -441,21 +431,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive RelAddr: relatedAdd.Address, RelPort: relatedAdd.Port, }) - if err != nil { - return nil, err - } - - for _, e := range candidate.Extensions() { - // overwrite the original candidate ID with the new one to avoid candidate duplication - if e.Key == ice.ExtensionKeyCandidateID { - e.Value = candidate.ID() - } - if err := ec.AddExtension(e); err != nil { - return nil, err - } - } - - return ec, nil } func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 4e6b422f6..6bbdbd984 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -75,8 +75,6 @@ type ConfigInput struct { DNSLabels domain.List LazyConnectionEnabled *bool - - MTU *uint16 } // Config Configuration type @@ -143,8 +141,6 @@ type Config struct { ClientCertKeyPair *tls.Certificate `json:"-"` LazyConnectionEnabled bool - - MTU uint16 } var ConfigDirOverride string @@ -497,16 +493,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { updated = true } - if input.MTU != nil && *input.MTU != config.MTU { - log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU) - config.MTU = *input.MTU - updated = true - } else if config.MTU == 0 { - config.MTU = iface.DefaultMTU - log.Infof("using default MTU %d", config.MTU) - updated = true - } - return updated, nil } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..6e1f83a9a 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -7,12 +7,12 @@ import ( "sync" "time" - "github.com/pion/stun/v3" + "github.com/pion/stun/v2" "github.com/pion/turn/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) // ProbeResult holds the info about the result of a relay probe request diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 9069cdcc5..ba27df654 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -2,13 +2,11 @@ package dnsinterceptor import ( "context" - "errors" "fmt" "net/netip" "runtime" "strings" "sync" - "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -28,8 +26,6 @@ import ( "github.com/netbirdio/netbird/route" ) -const dnsTimeout = 8 * time.Second - type domainMap map[domain.Domain][]netip.Prefix type internalDNATer interface { @@ -247,7 +243,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) + client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) if err != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return @@ -258,20 +254,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) - defer cancel() - - startTime := time.Now() - reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) + reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - elapsed := time.Since(startTime) - peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) - logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", - elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err) - } else { - logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) - } + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { logger.Errorf("failed writing DNS response: %v", err) } @@ -583,16 +568,3 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR } return } - -func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { - if d.statusRecorder == nil { - return "" - } - - peerState, err := d.statusRecorder.GetPeer(peerKey) - if err != nil { - return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err) - } - - return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState)) -} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..da5534902 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,9 +36,9 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" - "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" + "github.com/netbirdio/netbird/route" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -108,10 +108,6 @@ func NewManager(config ManagerConfig) *DefaultManager { notifier := notifier.NewNotifier() sysOps := systemops.NewSysOps(config.WGInterface, notifier) - if runtime.GOOS == "windows" && config.WGInterface != nil { - nbnet.SetVPNInterfaceName(config.WGInterface.Name()) - } - dm := &DefaultManager{ ctx: mCTX, stop: cancel, @@ -212,7 +208,7 @@ func (m *DefaultManager) Init() error { return nil } - if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil { + if err := m.sysOps.CleanupRouting(nil); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -223,7 +219,7 @@ func (m *DefaultManager) Init() error { ips := resolveURLsToIPs(initialAddresses) - if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil { + if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { return fmt.Errorf("setup routing: %w", err) } @@ -289,15 +285,11 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { - if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil { + if err := m.sysOps.CleanupRouting(stateManager); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") } - - if runtime.GOOS == "windows" { - nbnet.SetVPNInterfaceName("") - } } m.mux.Lock() @@ -376,11 +368,7 @@ func (m *DefaultManager) UpdateRoutes( var merr *multierror.Error if !m.disableClientRoutes { - - // Update route selector based on management server's isSelected status - m.updateRouteSelectorFromManagement(clientRoutes) - - filteredClientRoutes := m.routeSelector.FilterSelectedExitNodes(clientRoutes) + filteredClientRoutes := m.routeSelector.FilterSelected(clientRoutes) if err := m.updateSystemRoutes(filteredClientRoutes); err != nil { merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err)) @@ -442,7 +430,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { m.mux.Lock() defer m.mux.Unlock() - networks = m.routeSelector.FilterSelectedExitNodes(networks) + networks = m.routeSelector.FilterSelected(networks) m.notifier.OnNewRoutes(networks) @@ -595,106 +583,3 @@ func resolveURLsToIPs(urls []string) []net.IP { } return ips } - -// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server -func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) { - exitNodeInfo := m.collectExitNodeInfo(clientRoutes) - if len(exitNodeInfo.allIDs) == 0 { - return - } - - m.updateExitNodeSelections(exitNodeInfo) - m.logExitNodeUpdate(exitNodeInfo) -} - -type exitNodeInfo struct { - allIDs []route.NetID - selectedByManagement []route.NetID - userSelected []route.NetID - userDeselected []route.NetID -} - -func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo { - var info exitNodeInfo - - for haID, routes := range clientRoutes { - if !m.isExitNodeRoute(routes) { - continue - } - - netID := haID.NetID() - info.allIDs = append(info.allIDs, netID) - - if m.routeSelector.HasUserSelectionForRoute(netID) { - m.categorizeUserSelection(netID, &info) - } else { - m.checkManagementSelection(routes, netID, &info) - } - } - - return info -} - -func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool { - return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR -} - -func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) { - if m.routeSelector.IsSelected(netID) { - info.userSelected = append(info.userSelected, netID) - } else { - info.userDeselected = append(info.userDeselected, netID) - } -} - -func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID route.NetID, info *exitNodeInfo) { - for _, route := range routes { - if !route.SkipAutoApply { - info.selectedByManagement = append(info.selectedByManagement, netID) - break - } - } -} - -func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) { - routesToDeselect := m.getRoutesToDeselect(info.allIDs) - m.deselectExitNodes(routesToDeselect) - m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs) -} - -func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID { - var routesToDeselect []route.NetID - for _, netID := range allIDs { - if !m.routeSelector.HasUserSelectionForRoute(netID) { - routesToDeselect = append(routesToDeselect, netID) - } - } - return routesToDeselect -} - -func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) { - if len(routesToDeselect) == 0 { - return - } - - err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect) - if err != nil { - log.Warnf("Failed to deselect exit nodes: %v", err) - } -} - -func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) { - if len(selectedByManagement) == 0 { - return - } - - err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs) - if err != nil { - log.Warnf("Failed to select exit nodes: %v", err) - } -} - -func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) { - log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected", - len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected)) -} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index d2f02526c..2f13c2134 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -190,15 +190,14 @@ func TestManagerUpdateRoutes(t *testing.T) { name: "No Small Client Route Should Be Added", inputRoutes: []*route.Route{ { - ID: "a", - NetID: "routeA", - Peer: remotePeerKey1, - Network: netip.MustParsePrefix("0.0.0.0/0"), - NetworkType: route.IPv4Network, - Metric: 9999, - Masquerade: false, - Enabled: true, - SkipAutoApply: false, + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("0.0.0.0/0"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, }, }, inputSerial: 1, diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index 7cb8dae93..a375ce832 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -12,11 +12,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 26a548634..128afa2a5 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -3,6 +3,7 @@ package systemops import ( + "context" "errors" "fmt" "net" @@ -21,7 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - "github.com/netbirdio/netbird/client/net/hooks" + nbnet "github.com/netbirdio/netbird/util/net" ) const localSubnetsCacheTTL = 15 * time.Minute @@ -95,9 +96,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { return nil } - hooks.RemoveWriteHooks() - hooks.RemoveCloseHooks() - hooks.RemoveAddressRemoveHooks() + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() if err := r.refCounter.Flush(); err != nil { return fmt.Errorf("flush route manager: %w", err) @@ -289,7 +290,12 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { - beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } @@ -298,7 +304,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } - afterHook := func(connID hooks.ConnectionID) error { + afterHook := func(connID nbnet.ConnectionID) error { if err := r.refCounter.DecrementWithID(string(connID)); err != nil { return fmt.Errorf("remove route reference: %w", err) } @@ -311,20 +317,36 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M var merr *multierror.Error for _, ip := range initAddresses { - prefix, err := util.GetPrefixFromIP(ip) - if err != nil { - merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err)) - continue - } - if err := beforeHook("init", prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err)) + if err := beforeHook("init", ip); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) } } - hooks.AddWriteHook(beforeHook) - hooks.AddCloseHook(afterHook) + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } - hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error { + var merr *multierror.Error + for _, ip := range resolvedIPs { + merr = multierror.Append(merr, beforeHook(connID, ip.IP)) + } + return nberrors.FormatErrorOrNil(merr) + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.Decrement(prefix); err != nil { return fmt.Errorf("remove route reference: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..c1c1182bc 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -22,7 +22,6 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/vars" - nbnet "github.com/netbirdio/netbird/client/net" ) type dialer interface { @@ -144,11 +143,10 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - advancedRouting := nbnet.AdvancedRouting() - err := r.SetupRouting(nil, nil, advancedRouting) + err := r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) + assert.NoError(t, r.CleanupRouting(nil)) }) intf, err := net.InterfaceByName(wgInterface.Name()) @@ -343,11 +341,10 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - advancedRouting := nbnet.AdvancedRouting() - err := r.SetupRouting(nil, nil, advancedRouting) + err := r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) + assert.NoError(t, r.CleanupRouting(nil)) }) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -487,11 +484,10 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - advancedRouting := nbnet.AdvancedRouting() - err := r.SetupRouting(nil, nil, advancedRouting) + err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 99a363371..10356eae0 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -12,14 +12,14 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index bd10f131f..c0cef94ba 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) // IPRule contains IP rule information for debugging @@ -94,15 +94,15 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) { - if !advancedRouting { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { + if !nbnet.AdvancedRouting() { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) } defer func() { if err != nil { - if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { - if !advancedRouting { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + if !nbnet.AdvancedRouting() { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..f165f7779 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -20,11 +20,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index 959c697e4..ad37f611f 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) type PacketExpectation struct { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 7bce6af80..36e714ec4 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -8,7 +8,6 @@ import ( "net/netip" "os" "runtime/debug" - "sort" "strconv" "sync" "syscall" @@ -20,16 +19,9 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" ) -func init() { - nbnet.GetBestInterfaceFunc = GetBestInterface -} - -const ( - InfiniteLifetime = 0xffffffff -) +const InfiniteLifetime = 0xffffffff type RouteUpdateType int @@ -85,14 +77,6 @@ type MIB_IPFORWARD_TABLE2 struct { Table [1]MIB_IPFORWARD_ROW2 // Flexible array member } -// candidateRoute represents a potential route for selection during route lookup -type candidateRoute struct { - interfaceIndex uint32 - prefixLength uint8 - routeMetric uint32 - interfaceMetric int -} - // IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix type IP_ADDRESS_PREFIX struct { Prefix SOCKADDR_INET @@ -193,20 +177,11 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { - if advancedRouting { - return nil - } - - log.Infof("Using legacy routing setup with ref counters") +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { - if advancedRouting { - return nil - } - +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { return r.cleanupRefCounter(stateManager) } @@ -361,7 +336,7 @@ func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error { if e1 != 0 { return fmt.Errorf("CreateIpForwardEntry2: %w", e1) } - return fmt.Errorf("CreateIpForwardEntry2: code %d", windows.NTStatus(r1)) + return fmt.Errorf("CreateIpForwardEntry2: code %d", r1) } return nil } @@ -660,7 +635,10 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) { func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) { if table != nil { - _, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) + ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) + if ret != 0 { + log.Warnf("FreeMibTable failed with return code: %d", ret) + } } } @@ -674,7 +652,8 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute { entryPtr := basePtr + uintptr(i)*entrySize entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) - if detailed := buildWindowsDetailedRoute(entry); detailed != nil { + detailed := buildWindowsDetailedRoute(entry) + if detailed != nil { detailedRoutes = append(detailedRoutes, *detailed) } } @@ -823,46 +802,6 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { return ip } -// parseCandidatesFromTable extracts all matching candidate routes from the routing table -func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute { - var candidates []candidateRoute - entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{}) - basePtr := uintptr(unsafe.Pointer(&table.Table[0])) - - for i := uint32(0); i < table.NumEntries; i++ { - entryPtr := basePtr + uintptr(i)*entrySize - entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) - - if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil { - candidates = append(candidates, *candidate) - } - } - - return candidates -} - -// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry -// Returns nil if the route doesn't match the destination or should be skipped -func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute { - if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex { - return nil - } - - destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex)) - if !destPrefix.IsValid() || !destPrefix.Contains(dest) { - return nil - } - - interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family) - - return &candidateRoute{ - interfaceIndex: entry.InterfaceIndex, - prefixLength: entry.DestinationPrefix.PrefixLength, - routeMetric: entry.Metric, - interfaceMetric: interfaceMetric, - } -} - // getInterfaceMetric retrieves the interface metric for a given interface and address family func getInterfaceMetric(interfaceIndex uint32, family int16) int { if interfaceIndex == 0 { @@ -882,76 +821,6 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int { return int(ipInterfaceRow.Metric) } -// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric -func sortRouteCandidates(candidates []candidateRoute) { - sort.Slice(candidates, func(i, j int) bool { - if candidates[i].prefixLength != candidates[j].prefixLength { - return candidates[i].prefixLength > candidates[j].prefixLength - } - if candidates[i].routeMetric != candidates[j].routeMetric { - return candidates[i].routeMetric < candidates[j].routeMetric - } - return candidates[i].interfaceMetric < candidates[j].interfaceMetric - }) -} - -// GetBestInterface finds the best interface for reaching a destination, -// excluding the VPN interface to avoid routing loops. -// -// Route selection priority: -// 1. Longest prefix match (most specific route) -// 2. Lowest route metric -// 3. Lowest interface metric -func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { - var skipInterfaceIndex int - if vpnIntf != "" { - if iface, err := net.InterfaceByName(vpnIntf); err == nil { - skipInterfaceIndex = iface.Index - } else { - // not critical, if we cannot get ahold of the interface then we won't need to skip it - log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err) - } - } - - table, err := getWindowsRoutingTable() - if err != nil { - return nil, fmt.Errorf("get routing table: %w", err) - } - defer freeWindowsRoutingTable(table) - - candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex) - - if len(candidates) == 0 { - return nil, fmt.Errorf("no route to %s", dest) - } - - // Sort routes: prefix length -> route metric -> interface metric - sortRouteCandidates(candidates) - - for _, candidate := range candidates { - iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex)) - if err != nil { - log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err) - continue - } - - if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() { - continue - } - - if iface.Flags&net.FlagUp == 0 { - log.Debugf("interface %s is down, trying next route", iface.Name) - continue - } - - log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d", - dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric) - return iface, nil - } - - return nil, fmt.Errorf("no usable interface found for %s", dest) -} - // formatRouteAge formats the route age in seconds to a human-readable string func formatRouteAge(ageSeconds uint32) string { if ageSeconds == 0 { diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 3561adec4..523bd0b0d 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) var ( diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go index 57ea32f69..ac5a48e37 100644 --- a/client/internal/routemanager/util/ip.go +++ b/client/internal/routemanager/util/ip.go @@ -12,8 +12,18 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) { if !ok { return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip) } - addr = addr.Unmap() - prefix := netip.PrefixFrom(addr, addr.BitLen()) + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) return prefix, nil } diff --git a/client/internal/routemanager/vars/vars.go b/client/internal/routemanager/vars/vars.go index ac11dec8c..4aa986d2f 100644 --- a/client/internal/routemanager/vars/vars.go +++ b/client/internal/routemanager/vars/vars.go @@ -13,6 +13,4 @@ var ( Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0) Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) - - ExitNodeCIDR = "0.0.0.0/0" ) diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index e4a78599e..8ebdc63e5 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -9,27 +9,19 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" ) -const ( - exitNodeCIDR = "0.0.0.0/0" -) - type RouteSelector struct { mu sync.RWMutex deselectedRoutes map[route.NetID]struct{} - selectedRoutes map[route.NetID]struct{} deselectAll bool } func NewRouteSelector() *RouteSelector { return &RouteSelector{ deselectedRoutes: map[route.NetID]struct{}{}, - selectedRoutes: map[route.NetID]struct{}{}, deselectAll: false, } } @@ -40,14 +32,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al defer rs.mu.Unlock() if !appendRoute || rs.deselectAll { - if rs.deselectedRoutes == nil { - rs.deselectedRoutes = map[route.NetID]struct{}{} - } - if rs.selectedRoutes == nil { - rs.selectedRoutes = map[route.NetID]struct{}{} - } maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) for _, r := range allRoutes { rs.deselectedRoutes[r] = struct{}{} } @@ -60,7 +45,6 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al continue } delete(rs.deselectedRoutes, route) - rs.selectedRoutes[route] = struct{}{} } rs.deselectAll = false @@ -74,14 +58,7 @@ func (rs *RouteSelector) SelectAllRoutes() { defer rs.mu.Unlock() rs.deselectAll = false - if rs.deselectedRoutes == nil { - rs.deselectedRoutes = map[route.NetID]struct{}{} - } - if rs.selectedRoutes == nil { - rs.selectedRoutes = map[route.NetID]struct{}{} - } maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) } // DeselectRoutes removes specific routes from the selection. @@ -100,7 +77,6 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route. continue } rs.deselectedRoutes[route] = struct{}{} - delete(rs.selectedRoutes, route) } return errors.FormatErrorOrNil(err) @@ -112,14 +88,7 @@ func (rs *RouteSelector) DeselectAllRoutes() { defer rs.mu.Unlock() rs.deselectAll = true - if rs.deselectedRoutes == nil { - rs.deselectedRoutes = map[route.NetID]struct{}{} - } - if rs.selectedRoutes == nil { - rs.selectedRoutes = map[route.NetID]struct{}{} - } maps.Clear(rs.deselectedRoutes) - maps.Clear(rs.selectedRoutes) } // IsSelected checks if a specific route is selected. @@ -128,14 +97,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { defer rs.mu.RUnlock() if rs.deselectAll { - log.Debugf("Route %s not selected (deselect all)", routeID) return false } _, deselected := rs.deselectedRoutes[routeID] - isSelected := !deselected - log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected) - return isSelected + return !deselected } // FilterSelected removes unselected routes from the provided map. @@ -158,98 +124,15 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { return filtered } -// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route -func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool { - rs.mu.RLock() - defer rs.mu.RUnlock() - - _, selected := rs.selectedRoutes[routeID] - _, deselected := rs.deselectedRoutes[routeID] - return selected || deselected -} - -func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap { - rs.mu.RLock() - defer rs.mu.RUnlock() - - if rs.deselectAll { - return route.HAMap{} - } - - filtered := make(route.HAMap, len(routes)) - for id, rt := range routes { - netID := id.NetID() - if rs.isDeselected(netID) { - continue - } - - if !isExitNode(rt) { - filtered[id] = rt - continue - } - - rs.applyExitNodeFilter(id, netID, rt, filtered) - } - - return filtered -} - -func (rs *RouteSelector) isDeselected(netID route.NetID) bool { - _, deselected := rs.deselectedRoutes[netID] - return deselected || rs.deselectAll -} - -func isExitNode(rt []*route.Route) bool { - return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR -} - -func (rs *RouteSelector) applyExitNodeFilter( - id route.HAUniqueID, - netID route.NetID, - rt []*route.Route, - out route.HAMap, -) { - - if rs.hasUserSelections() { - // user made explicit selects/deselects - if rs.IsSelected(netID) { - out[id] = rt - } - return - } - - // no explicit selections: only include routes marked !SkipAutoApply (=AutoApply) - sel := collectSelected(rt) - if len(sel) > 0 { - out[id] = sel - } -} - -func (rs *RouteSelector) hasUserSelections() bool { - return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0 -} - -func collectSelected(rt []*route.Route) []*route.Route { - var sel []*route.Route - for _, r := range rt { - if !r.SkipAutoApply { - sel = append(sel, r) - } - } - return sel -} - // MarshalJSON implements the json.Marshaler interface func (rs *RouteSelector) MarshalJSON() ([]byte, error) { rs.mu.RLock() defer rs.mu.RUnlock() return json.Marshal(struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` DeselectAll bool `json:"deselect_all"` }{ - SelectedRoutes: rs.selectedRoutes, DeselectedRoutes: rs.deselectedRoutes, DeselectAll: rs.deselectAll, }) @@ -264,13 +147,11 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { // Check for null or empty JSON if len(data) == 0 || string(data) == "null" { rs.deselectedRoutes = map[route.NetID]struct{}{} - rs.selectedRoutes = map[route.NetID]struct{}{} rs.deselectAll = false return nil } var temp struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` DeselectAll bool `json:"deselect_all"` } @@ -279,16 +160,12 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { return err } - rs.selectedRoutes = temp.SelectedRoutes rs.deselectedRoutes = temp.DeselectedRoutes rs.deselectAll = temp.DeselectAll if rs.deselectedRoutes == nil { rs.deselectedRoutes = map[route.NetID]struct{}{} } - if rs.selectedRoutes == nil { - rs.selectedRoutes = map[route.NetID]struct{}{} - } return nil } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 5faea2456..cfa723246 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -1,7 +1,6 @@ package routeselector_test import ( - "net/netip" "slices" "testing" @@ -274,62 +273,6 @@ func TestRouteSelector_FilterSelected(t *testing.T) { }, filtered) } -func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) { - rs := routeselector.NewRouteSelector() - - // Create test routes - exitNode1 := &route.Route{ - ID: "route1", - NetID: "net1", - Network: netip.MustParsePrefix("0.0.0.0/0"), - Peer: "peer1", - SkipAutoApply: false, - } - exitNode2 := &route.Route{ - ID: "route2", - NetID: "net1", - Network: netip.MustParsePrefix("0.0.0.0/0"), - Peer: "peer2", - SkipAutoApply: true, - } - normalRoute := &route.Route{ - ID: "route3", - NetID: "net2", - Network: netip.MustParsePrefix("192.168.1.0/24"), - Peer: "peer3", - SkipAutoApply: false, - } - - routes := route.HAMap{ - "net1|0.0.0.0/0": {exitNode1, exitNode2}, - "net2|192.168.1.0/24": {normalRoute}, - } - - // Test filtering - filtered := rs.FilterSelectedExitNodes(routes) - - // Should only include selected exit nodes and all normal routes - assert.Len(t, filtered, 2) - assert.Len(t, filtered["net1|0.0.0.0/0"], 1) // Only the selected exit node - assert.Equal(t, exitNode1.ID, filtered["net1|0.0.0.0/0"][0].ID) - assert.Len(t, filtered["net2|192.168.1.0/24"], 1) // Normal route should be included - assert.Equal(t, normalRoute.ID, filtered["net2|192.168.1.0/24"][0].ID) - - // Test with deselected routes - err := rs.DeselectRoutes([]route.NetID{"net1"}, []route.NetID{"net1", "net2"}) - assert.NoError(t, err) - filtered = rs.FilterSelectedExitNodes(routes) - assert.Len(t, filtered, 1) // Only normal route should remain - assert.Len(t, filtered["net2|192.168.1.0/24"], 1) - assert.Equal(t, normalRoute.ID, filtered["net2|192.168.1.0/24"][0].ID) - - // Test with deselect all - rs = routeselector.NewRouteSelector() - rs.DeselectAllRoutes() - filtered = rs.FilterSelectedExitNodes(routes) - assert.Len(t, filtered, 0) // No routes should be selected -} - func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialRoutes := []route.NetID{"route1", "route2", "route3"} newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go index 8961eaa69..e80adb42b 100644 --- a/client/internal/stdnet/dialer.go +++ b/client/internal/stdnet/dialer.go @@ -5,7 +5,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) // Dial connects to the address on the named network. diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go index d3be1896f..9ce0a5556 100644 --- a/client/internal/stdnet/listener.go +++ b/client/internal/stdnet/listener.go @@ -6,7 +6,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) // ListenPacket listens for incoming packets on the given network and address. diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 4b031c05c..aa9fdd045 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -9,7 +9,6 @@ import ( "sync" "time" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" ) @@ -33,15 +32,9 @@ type Net struct { // NewNetWithDiscover creates a new StdNet instance. func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { n := &Net{ + iFaceDiscover: newMobileIFaceDiscover(iFaceDiscover), interfaceFilter: InterfaceFilter(disallowList), } - // current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client - // so in android cli use pionDiscover - if netstack.IsEnabled() { - n.iFaceDiscover = pionDiscover{} - } else { - n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover) - } return n, n.UpdateInterfaces() } diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go deleted file mode 100644 index 78d70c15b..000000000 --- a/client/internal/wg_iface_monitor.go +++ /dev/null @@ -1,98 +0,0 @@ -package internal - -import ( - "context" - "errors" - "fmt" - "net" - "runtime" - "time" - - log "github.com/sirupsen/logrus" -) - -// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine -// if the interface is deleted externally while the engine is running. -type WGIfaceMonitor struct { - done chan struct{} -} - -// NewWGIfaceMonitor creates a new WGIfaceMonitor instance. -func NewWGIfaceMonitor() *WGIfaceMonitor { - return &WGIfaceMonitor{ - done: make(chan struct{}), - } -} - -// Start begins monitoring the WireGuard interface. -// It relies on the provided context cancellation to stop. -func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { - defer close(m.done) - - // Skip on mobile platforms as they handle interface lifecycle differently - if runtime.GOOS == "android" || runtime.GOOS == "ios" { - log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS) - return false, errors.New("not supported on mobile platforms") - } - - if ifaceName == "" { - log.Debugf("Interface monitor: empty interface name, skipping monitor") - return false, errors.New("empty interface name") - } - - // Get initial interface index to track the specific interface instance - expectedIndex, err := getInterfaceIndex(ifaceName) - if err != nil { - log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName) - return false, fmt.Errorf("interface %s not found: %w", ifaceName, err) - } - - log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) - - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - log.Infof("Interface monitor: stopped for %s", ifaceName) - return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) - case <-ticker.C: - currentIndex, err := getInterfaceIndex(ifaceName) - if err != nil { - // Interface was deleted - log.Infof("Interface monitor: %s deleted", ifaceName) - return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) - } - - // Check if interface index changed (interface was recreated) - if currentIndex != expectedIndex { - log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", - ifaceName, expectedIndex, currentIndex) - return true, nil - } - } - } - -} - -// getInterfaceIndex returns the index of a network interface by name. -// Returns an error if the interface is not found. -func getInterfaceIndex(name string) (int, error) { - if name == "" { - return 0, fmt.Errorf("empty interface name") - } - ifi, err := net.InterfaceByName(name) - if err != nil { - // Check if it's specifically a "not found" error - if errors.Is(err, &net.OpError{}) { - // On some systems, this might be a "not found" error - return 0, fmt.Errorf("interface not found: %w", err) - } - return 0, fmt.Errorf("failed to lookup interface: %w", err) - } - if ifi == nil { - return 0, fmt.Errorf("interface not found") - } - return ifi.Index, nil -} diff --git a/client/net/conn.go b/client/net/conn.go deleted file mode 100644 index 918e7f628..000000000 --- a/client/net/conn.go +++ /dev/null @@ -1,49 +0,0 @@ -//go:build !ios - -package net - -import ( - "io" - "net" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/net/hooks" -) - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID hooks.ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. -func (c *Conn) Close() error { - return closeConn(c.ID, c.Conn) -} - -// TCPConn wraps net.TCPConn to override its Close method to include hook functionality. -type TCPConn struct { - *net.TCPConn - ID hooks.ConnectionID -} - -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. -func (c *TCPConn) Close() error { - return closeConn(c.ID, c.TCPConn) -} - -// closeConn is a helper function to close connections and execute close hooks. -func closeConn(id hooks.ConnectionID, conn io.Closer) error { - err := conn.Close() - - closeHooks := hooks.GetCloseHooks() - for _, hook := range closeHooks { - if err := hook(id); err != nil { - log.Errorf("Error executing close hook: %v", err) - } - } - - return err -} diff --git a/client/net/dial.go b/client/net/dial.go deleted file mode 100644 index 041a00e5d..000000000 --- a/client/net/dial.go +++ /dev/null @@ -1,82 +0,0 @@ -//go:build !ios - -package net - -import ( - "fmt" - "net" - "sync" - - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" -) - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - switch c := conn.(type) { - case *net.UDPConn: - // Advanced routing: plain connection - return c, nil - case *Conn: - // Legacy routing: wrapped connection preserves close hooks - udpConn, ok := c.Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn) - } - return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil - } - - if err := conn.Close(); err != nil { - log.Errorf("failed to close connection: %v", err) - } - return nil, fmt.Errorf("unexpected connection type: %T", conn) -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - switch c := conn.(type) { - case *net.TCPConn: - // Advanced routing: plain connection - return c, nil - case *Conn: - // Legacy routing: wrapped connection preserves close hooks - tcpConn, ok := c.Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn) - } - return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil - } - - if err := conn.Close(); err != nil { - log.Errorf("failed to close connection: %v", err) - } - return nil, fmt.Errorf("unexpected connection type: %T", conn) -} diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go deleted file mode 100644 index 2e1eb53d8..000000000 --- a/client/net/dialer_dial.go +++ /dev/null @@ -1,87 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - - nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/client/net/hooks" -) - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - log.Debugf("Dialing %s %s", network, address) - - if CustomRoutingDisabled() || AdvancedRouting() { - return d.Dialer.DialContext(ctx, network, address) - } - - connID := hooks.GenerateConnID() - if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error { - if ctx.Err() != nil { - return ctx.Err() - } - - writeHooks := hooks.GetWriteHooks() - if len(writeHooks) == 0 { - return nil - } - - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - - resolver := customResolver - if resolver == nil { - resolver = net.DefaultResolver - } - - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var merr *multierror.Error - for _, ip := range ips { - prefix, err := util.GetPrefixFromIP(ip.IP) - if err != nil { - merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err)) - continue - } - for _, hook := range writeHooks { - if err := hook(connID, prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err)) - } - } - } - - return nberrors.FormatErrorOrNil(merr) -} diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go deleted file mode 100644 index 18ebc6ad1..000000000 --- a/client/net/dialer_init_generic.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux && !windows - -package net - -func (d *Dialer) init() { - // implemented on Linux, Android, and Windows only -} diff --git a/client/net/dialer_init_windows.go b/client/net/dialer_init_windows.go deleted file mode 100644 index 6eefe5b1e..000000000 --- a/client/net/dialer_init_windows.go +++ /dev/null @@ -1,5 +0,0 @@ -package net - -func (d *Dialer) init() { - d.Dialer.Control = applyUnicastIFToSocket -} diff --git a/client/net/env_android.go b/client/net/env_android.go deleted file mode 100644 index 9d89951a1..000000000 --- a/client/net/env_android.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build android - -package net - -// Init initializes the network environment for Android -func Init() { - // No initialization needed on Android -} - -// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. -// Always returns true on Android since we cannot handle routes dynamically. -func AdvancedRouting() bool { - return true -} - -// SetVPNInterfaceName is a no-op on Android -func SetVPNInterfaceName(name string) { - // No-op on Android - not needed for Android VPN service -} - -// GetVPNInterfaceName returns empty string on Android -func GetVPNInterfaceName() string { - return "" -} diff --git a/client/net/env_generic.go b/client/net/env_generic.go deleted file mode 100644 index f467930c3..000000000 --- a/client/net/env_generic.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build !linux && !windows && !android - -package net - -// Init initializes the network environment (no-op on non-Linux/Windows platforms) -func Init() { - // No-op on non-Linux/Windows platforms -} - -// AdvancedRouting returns false on non-Linux/Windows platforms -func AdvancedRouting() bool { - return false -} - -// SetVPNInterfaceName is a no-op on non-Windows platforms -func SetVPNInterfaceName(name string) { - // No-op on non-Windows platforms -} - -// GetVPNInterfaceName returns empty string on non-Windows platforms -func GetVPNInterfaceName() string { - return "" -} diff --git a/client/net/env_windows.go b/client/net/env_windows.go deleted file mode 100644 index 7e8868ba5..000000000 --- a/client/net/env_windows.go +++ /dev/null @@ -1,67 +0,0 @@ -//go:build windows - -package net - -import ( - "os" - "strconv" - "sync" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/iface/netstack" -) - -var ( - vpnInterfaceName string - vpnInitMutex sync.RWMutex - - advancedRoutingSupported bool -) - -func Init() { - advancedRoutingSupported = checkAdvancedRoutingSupport() -} - -func checkAdvancedRoutingSupport() bool { - var err error - var legacyRouting bool - if val := os.Getenv(envUseLegacyRouting); val != "" { - legacyRouting, err = strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) - } - } - - if legacyRouting || netstack.IsEnabled() { - log.Info("advanced routing has been requested to be disabled") - return false - } - - log.Info("system supports advanced routing") - - return true -} - -// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes -func AdvancedRouting() bool { - return advancedRoutingSupported -} - -// GetVPNInterfaceName returns the stored VPN interface name -func GetVPNInterfaceName() string { - vpnInitMutex.RLock() - defer vpnInitMutex.RUnlock() - return vpnInterfaceName -} - -// SetVPNInterfaceName sets the VPN interface name for lazy initialization -func SetVPNInterfaceName(name string) { - vpnInitMutex.Lock() - defer vpnInitMutex.Unlock() - vpnInterfaceName = name - - if name != "" { - log.Infof("VPN interface name set to %s for route exclusion", name) - } -} diff --git a/client/net/hooks/hooks.go b/client/net/hooks/hooks.go deleted file mode 100644 index 93d8e18ef..000000000 --- a/client/net/hooks/hooks.go +++ /dev/null @@ -1,93 +0,0 @@ -package hooks - -import ( - "net/netip" - "slices" - "sync" - - "github.com/google/uuid" -) - -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -// GenerateConnID generates a unique identifier for each connection. -func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} - -type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error -type CloseHookFunc func(connID ConnectionID) error -type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error - -var ( - hooksMutex sync.RWMutex - - writeHooks []WriteHookFunc - closeHooks []CloseHookFunc - addressRemoveHooks []AddressRemoveHookFunc -) - -// AddWriteHook allows adding a new hook to be executed before writing/dialing. -func AddWriteHook(hook WriteHookFunc) { - hooksMutex.Lock() - defer hooksMutex.Unlock() - writeHooks = append(writeHooks, hook) -} - -// AddCloseHook allows adding a new hook to be executed on connection close. -func AddCloseHook(hook CloseHookFunc) { - hooksMutex.Lock() - defer hooksMutex.Unlock() - closeHooks = append(closeHooks, hook) -} - -// RemoveWriteHooks removes all write hooks. -func RemoveWriteHooks() { - hooksMutex.Lock() - defer hooksMutex.Unlock() - writeHooks = nil -} - -// RemoveCloseHooks removes all close hooks. -func RemoveCloseHooks() { - hooksMutex.Lock() - defer hooksMutex.Unlock() - closeHooks = nil -} - -// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed. -func AddAddressRemoveHook(hook AddressRemoveHookFunc) { - hooksMutex.Lock() - defer hooksMutex.Unlock() - addressRemoveHooks = append(addressRemoveHooks, hook) -} - -// RemoveAddressRemoveHooks removes all listener address hooks. -func RemoveAddressRemoveHooks() { - hooksMutex.Lock() - defer hooksMutex.Unlock() - addressRemoveHooks = nil -} - -// GetWriteHooks returns a copy of the current write hooks. -func GetWriteHooks() []WriteHookFunc { - hooksMutex.RLock() - defer hooksMutex.RUnlock() - return slices.Clone(writeHooks) -} - -// GetCloseHooks returns a copy of the current close hooks. -func GetCloseHooks() []CloseHookFunc { - hooksMutex.RLock() - defer hooksMutex.RUnlock() - return slices.Clone(closeHooks) -} - -// GetAddressRemoveHooks returns a copy of the current listener address remove hooks. -func GetAddressRemoveHooks() []AddressRemoveHookFunc { - hooksMutex.RLock() - defer hooksMutex.RUnlock() - return slices.Clone(addressRemoveHooks) -} diff --git a/client/net/listen.go b/client/net/listen.go deleted file mode 100644 index da7262806..000000000 --- a/client/net/listen.go +++ /dev/null @@ -1,47 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" -) - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - switch c := conn.(type) { - case *net.UDPConn: - // Advanced routing: plain connection - return c, nil - case *PacketConn: - // Legacy routing: wrapped connection for hooks - udpConn, ok := c.PacketConn.(*net.UDPConn) - if !ok { - if err := c.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn) - } - return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil - } - - if err := conn.Close(); err != nil { - log.Errorf("failed to close connection: %v", err) - } - return nil, fmt.Errorf("unexpected connection type: %T", conn) -} diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go deleted file mode 100644 index 4f8f17ab2..000000000 --- a/client/net/listener_init_generic.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux && !windows - -package net - -func (l *ListenerConfig) init() { - // implemented on Linux, Android, and Windows only -} diff --git a/client/net/listener_init_windows.go b/client/net/listener_init_windows.go deleted file mode 100644 index a9399b5f1..000000000 --- a/client/net/listener_init_windows.go +++ /dev/null @@ -1,8 +0,0 @@ -package net - -func (l *ListenerConfig) init() { - // TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses. - // For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case - // the interface will be selected that serves the default route. - l.ListenConfig.Control = applyUnicastIFToSocket -} diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go deleted file mode 100644 index 0bb5ad67d..000000000 --- a/client/net/listener_listen.go +++ /dev/null @@ -1,153 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - - nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/client/net/hooks" -) - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. -func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - if CustomRoutingDisabled() || AdvancedRouting() { - return l.ListenConfig.ListenPacket(ctx, network, address) - } - - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := hooks.GenerateConnID() - - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID hooks.ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { - log.Errorf("Failed to call write hooks: %v", err) - } - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - defer c.seenAddrs.Clear() - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID hooks.ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { - log.Errorf("Failed to call write hooks: %v", err) - } - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - defer c.seenAddrs.Clear() - return closeConn(c.ID, c.UDPConn) -} - -// RemoveAddress removes an address from the seen cache and triggers removal hooks. -func (c *PacketConn) RemoveAddress(addr string) { - if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { - return - } - - ipStr, _, err := net.SplitHostPort(addr) - if err != nil { - log.Errorf("Error splitting IP address and port: %v", err) - return - } - - ipAddr, err := netip.ParseAddr(ipStr) - if err != nil { - log.Errorf("Error parsing IP address %s: %v", ipStr, err) - return - } - - prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen()) - - addressRemoveHooks := hooks.GetAddressRemoveHooks() - if len(addressRemoveHooks) == 0 { - return - } - - for _, hook := range addressRemoveHooks { - if err := hook(c.ID, prefix); err != nil { - log.Errorf("Error executing listener address remove hook: %v", err) - } - } -} - -// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality -func WrapPacketConn(conn net.PacketConn) net.PacketConn { - if AdvancedRouting() { - // hooks not required for advanced routing - return conn - } - return &PacketConn{ - PacketConn: conn, - ID: hooks.GenerateConnID(), - seenAddrs: &sync.Map{}, - } -} - -func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error { - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded { - return nil - } - - writeHooks := hooks.GetWriteHooks() - if len(writeHooks) == 0 { - return nil - } - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr) - } - - prefix, err := util.GetPrefixFromIP(udpAddr.IP) - if err != nil { - return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err) - } - - log.Debugf("Listener resolved IP for %s: %s", addr, prefix) - - var merr *multierror.Error - for _, hook := range writeHooks { - if err := hook(id, prefix); err != nil { - merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err)) - } - } - - return nberrors.FormatErrorOrNil(merr) -} diff --git a/client/net/net_windows.go b/client/net/net_windows.go deleted file mode 100644 index 649d83aaf..000000000 --- a/client/net/net_windows.go +++ /dev/null @@ -1,284 +0,0 @@ -package net - -import ( - "context" - "errors" - "fmt" - "net" - "net/netip" - "strconv" - "strings" - "syscall" - "time" - "unsafe" - - log "github.com/sirupsen/logrus" - "golang.org/x/sys/windows" -) - -const ( - // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options - IpUnicastIf = 31 - Ipv6UnicastIf = 31 - - // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options - Ipv6V6only = 27 -) - -// GetBestInterfaceFunc is set at runtime to avoid import cycle -var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error) - -// nativeToBigEndian converts a uint32 from native byte order to big-endian -func nativeToBigEndian(v uint32) uint32 { - return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24 -} - -// parseDestinationAddress parses the destination address from various formats -func parseDestinationAddress(network, address string) (netip.Addr, error) { - if address == "" { - if strings.HasSuffix(network, "6") { - return netip.IPv6Unspecified(), nil - } - return netip.IPv4Unspecified(), nil - } - - if addrPort, err := netip.ParseAddrPort(address); err == nil { - return addrPort.Addr(), nil - } - - if dest, err := netip.ParseAddr(address); err == nil { - return dest, nil - } - - host, _, err := net.SplitHostPort(address) - if err != nil { - // No port, treat whole string as host - host = address - } - - if host == "" { - if strings.HasSuffix(network, "6") { - return netip.IPv6Unspecified(), nil - } - return netip.IPv4Unspecified(), nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) - if err != nil || len(ips) == 0 { - return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err) - } - - dest, ok := netip.AddrFromSlice(ips[0].IP) - if !ok { - return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP) - } - - if ips[0].Zone != "" { - dest = dest.WithZone(ips[0].Zone) - } - - return dest, nil -} - -func getInterfaceFromZone(zone string) *net.Interface { - if zone == "" { - return nil - } - - idx, err := strconv.Atoi(zone) - if err != nil { - log.Debugf("invalid zone format for Windows (expected numeric): %s", zone) - return nil - } - - iface, err := net.InterfaceByIndex(idx) - if err != nil { - log.Debugf("failed to get interface by index %d from zone: %v", idx, err) - return nil - } - - return iface -} - -type interfaceSelection struct { - iface4 *net.Interface - iface6 *net.Interface -} - -func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection { - iface := getInterfaceFromZone(zone) - if iface == nil { - return nil - } - - if dest.Is6() { - return &interfaceSelection{iface6: iface} - } - return &interfaceSelection{iface4: iface} -} - -func selectInterfaceForUnspecified() (*interfaceSelection, error) { - if GetBestInterfaceFunc == nil { - return nil, errors.New("GetBestInterfaceFunc not initialized") - } - - var result interfaceSelection - vpnIfaceName := GetVPNInterfaceName() - - if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil { - result.iface4 = iface4 - } else { - log.Debugf("No IPv4 default route found: %v", err) - } - - if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil { - result.iface6 = iface6 - } else { - log.Debugf("No IPv6 default route found: %v", err) - } - - if result.iface4 == nil && result.iface6 == nil { - return nil, errors.New("no default routes found") - } - - return &result, nil -} - -func selectInterface(dest netip.Addr) (*interfaceSelection, error) { - if zone := dest.Zone(); zone != "" { - if selection := selectInterfaceForZone(dest, zone); selection != nil { - return selection, nil - } - } - - if dest.IsUnspecified() { - return selectInterfaceForUnspecified() - } - - if GetBestInterfaceFunc == nil { - return nil, errors.New("GetBestInterfaceFunc not initialized") - } - - iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName()) - if err != nil { - return nil, fmt.Errorf("find route for %s: %w", dest, err) - } - - if dest.Is6() { - return &interfaceSelection{iface6: iface}, nil - } - return &interfaceSelection{iface4: iface}, nil -} - -func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error { - ifaceIndexBE := nativeToBigEndian(uint32(iface.Index)) - if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil { - return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) - } - return nil -} - -func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error { - if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil { - return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) - } - return nil -} - -func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error { - // The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.) - // Never generic ones (udp, tcp, ip) - - switch { - case strings.HasSuffix(network, "4"): - // IPv4-only socket (udp4, tcp4, ip4) - return setUnicastIfIPv4(fd, network, selection, address) - - case strings.HasSuffix(network, "6"): - // IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only - return setUnicastIfIPv6(fd, network, selection, address) - } - - // Shouldn't reach here based on Go's documented behavior - return fmt.Errorf("unexpected network type: %s", network) -} - -func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error { - if selection.iface4 == nil { - return nil - } - - if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { - return err - } - - log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address) - return nil -} - -func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error { - isDualStack := checkDualStack(fd) - - // For dual-stack sockets, also set the IPv4 option - if isDualStack && selection.iface4 != nil { - if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { - return err - } - log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address) - } - - if selection.iface6 == nil { - return nil - } - - if err := setIPv6UnicastIF(fd, selection.iface6); err != nil { - return err - } - - log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address) - return nil -} - -func checkDualStack(fd uintptr) bool { - var v6Only int - v6OnlyLen := int32(unsafe.Sizeof(v6Only)) - err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen) - return err == nil && v6Only == 0 -} - -// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address -func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error { - if !AdvancedRouting() { - return nil - } - - dest, err := parseDestinationAddress(network, address) - if err != nil { - return err - } - - dest = dest.Unmap() - - if !dest.IsValid() { - return fmt.Errorf("invalid destination address for %s", address) - } - - selection, err := selectInterface(dest) - if err != nil { - return err - } - - var controlErr error - err = c.Control(func(fd uintptr) { - controlErr = setUnicastIf(fd, network, selection, address) - }) - - if err != nil { - return fmt.Errorf("control: %w", err) - } - - return controlErr -} diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 7c9fa021a..2422d2683 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -2,7 +2,7 @@ set -eEuo pipefail : ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} +: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() @@ -39,7 +39,7 @@ wait_for_message() { info "not waiting for log line ${message@Q} due to zero timeout." elif test -n "${log_file_path}"; then info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) + grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) else info "log file unsupported, sleeping for ${timeout} seconds..." sleep "${timeout}" @@ -81,7 +81,7 @@ wait_for_daemon_startup() { login_if_needed() { local timeout="${1}" - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then + if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then info "already logged in, skipping 'netbird up'..." else info "logging in..." diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 841e3c0f7..60835d1cd 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v6.32.1 +// protoc v5.29.3 // source: daemon.proto package proto @@ -278,7 +278,6 @@ type LoginRequest struct { BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"` ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` - Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -531,13 +530,6 @@ func (x *LoginRequest) GetUsername() string { return "" } -func (x *LoginRequest) GetMtu() int64 { - if x != nil && x.Mtu != nil { - return *x.Mtu - } - return 0 -} - type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -794,10 +786,8 @@ type StatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"` ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"` - // the UI do not using this yet, but CLIs could use it to wait until the status is ready - WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *StatusRequest) Reset() { @@ -844,13 +834,6 @@ func (x *StatusRequest) GetShouldRunProbes() bool { return false } -func (x *StatusRequest) GetWaitForReady() bool { - if x != nil && x.WaitForReady != nil { - return *x.WaitForReady - } - return false -} - type StatusResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // status of the server. @@ -1051,7 +1034,6 @@ type GetConfigResponse struct { AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"` InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"` WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"` - Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"` DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"` ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"` RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` @@ -1147,13 +1129,6 @@ func (x *GetConfigResponse) GetWireguardPort() int64 { return 0 } -func (x *GetConfigResponse) GetMtu() int64 { - if x != nil { - return x.Mtu - } - return 0 -} - func (x *GetConfigResponse) GetDisableAutoConnect() bool { if x != nil { return x.DisableAutoConnect @@ -3704,7 +3679,6 @@ type SetConfigRequest struct { // cleanDNSLabels clean map list of DNS labels. CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"` DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` - Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -3928,13 +3902,6 @@ func (x *SetConfigRequest) GetDnsRouteInterval() *durationpb.Duration { return nil } -func (x *SetConfigRequest) GetMtu() int64 { - if x != nil && x.Mtu != nil { - return *x.Mtu - } - return 0 -} - type SetConfigResponse struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -4608,7 +4575,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xc3\x0e\n" + + "\fEmptyRequest\"\xa4\x0e\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4644,8 +4611,7 @@ const file_daemon_proto_rawDesc = "" + "\x15lazyConnectionEnabled\x18\x1c \x01(\bH\x0fR\x15lazyConnectionEnabled\x88\x01\x01\x12(\n" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + - "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4664,8 +4630,7 @@ const file_daemon_proto_rawDesc = "" + "\x16_lazyConnectionEnabledB\x10\n" + "\x0e_block_inboundB\x0e\n" + "\f_profileNameB\v\n" + - "\t_usernameB\x06\n" + - "\x04_mtu\"\xb5\x01\n" + + "\t_username\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + @@ -4682,12 +4647,10 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\f\n" + "\n" + - "UpResponse\"\xa1\x01\n" + + "UpResponse\"g\n" + "\rStatusRequest\x12,\n" + "\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" + - "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" + - "\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" + - "\r_waitForReady\"\x82\x01\n" + + "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" + "\x0eStatusResponse\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status\x122\n" + "\n" + @@ -4698,7 +4661,7 @@ const file_daemon_proto_rawDesc = "" + "\fDownResponse\"P\n" + "\x10GetConfigRequest\x12 \n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + - "\busername\x18\x02 \x01(\tR\busername\"\xb5\x06\n" + + "\busername\x18\x02 \x01(\tR\busername\"\xa3\x06\n" + "\x11GetConfigResponse\x12$\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\n" + @@ -4708,8 +4671,7 @@ const file_daemon_proto_rawDesc = "" + "\fpreSharedKey\x18\x04 \x01(\tR\fpreSharedKey\x12\x1a\n" + "\badminURL\x18\x05 \x01(\tR\badminURL\x12$\n" + "\rinterfaceName\x18\x06 \x01(\tR\rinterfaceName\x12$\n" + - "\rwireguardPort\x18\a \x01(\x03R\rwireguardPort\x12\x10\n" + - "\x03mtu\x18\b \x01(\x03R\x03mtu\x12.\n" + + "\rwireguardPort\x18\a \x01(\x03R\rwireguardPort\x12.\n" + "\x12disableAutoConnect\x18\t \x01(\bR\x12disableAutoConnect\x12*\n" + "\x10serverSSHAllowed\x18\n" + " \x01(\bR\x10serverSSHAllowed\x12*\n" + @@ -4923,7 +4885,7 @@ const file_daemon_proto_rawDesc = "" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\f_profileNameB\v\n" + "\t_username\"\x17\n" + - "\x15SwitchProfileResponse\"\x8e\r\n" + + "\x15SwitchProfileResponse\"\xef\f\n" + "\x10SetConfigRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + @@ -4955,8 +4917,7 @@ const file_daemon_proto_rawDesc = "" + "\n" + "dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" + "\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" + - "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" + + "\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4973,8 +4934,7 @@ const file_daemon_proto_rawDesc = "" + "\x16_disable_notificationsB\x18\n" + "\x16_lazyConnectionEnabledB\x10\n" + "\x0e_block_inboundB\x13\n" + - "\x11_dnsRouteIntervalB\x06\n" + - "\x04_mtu\"\x13\n" + + "\x11_dnsRouteInterval\"\x13\n" + "\x11SetConfigResponse\"Q\n" + "\x11AddProfileRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" + @@ -5242,7 +5202,6 @@ func file_daemon_proto_init() { } file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} - file_daemon_proto_msgTypes[7].OneofWrappers = []any{} file_daemon_proto_msgTypes[26].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 5b27b4d98..fa54071ec 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -156,8 +156,6 @@ message LoginRequest { optional string profileName = 30; optional string username = 31; - - optional int64 mtu = 32; } message LoginResponse { @@ -186,8 +184,6 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; bool shouldRunProbes = 2; - // the UI do not using this yet, but CLIs could use it to wait until the status is ready - optional bool waitForReady = 3; } message StatusResponse{ @@ -227,8 +223,6 @@ message GetConfigResponse { int64 wireguardPort = 7; - int64 mtu = 8; - bool disableAutoConnect = 9; bool serverSSHAllowed = 10; @@ -544,36 +538,36 @@ message SetConfigRequest { string profileName = 2; // managementUrl to authenticate. string managementUrl = 3; - + // adminUrl to manage keys. string adminURL = 4; - + optional bool rosenpassEnabled = 5; - + optional string interfaceName = 6; - + optional int64 wireguardPort = 7; - + optional string optionalPreSharedKey = 8; - + optional bool disableAutoConnect = 9; - + optional bool serverSSHAllowed = 10; - + optional bool rosenpassPermissive = 11; - + optional bool networkMonitor = 12; - + optional bool disable_client_routes = 13; optional bool disable_server_routes = 14; optional bool disable_dns = 15; optional bool disable_firewall = 16; optional bool block_lan_access = 17; - + optional bool disable_notifications = 18; - + optional bool lazyConnectionEnabled = 19; - + optional bool block_inbound = 20; repeated string natExternalIPs = 21; @@ -589,7 +583,6 @@ message SetConfigRequest { optional google.protobuf.Duration dnsRouteInterval = 27; - optional int64 mtu = 28; } message SetConfigResponse{} @@ -640,4 +633,4 @@ message GetFeaturesRequest{} message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; -} +} \ No newline at end of file diff --git a/client/server/server.go b/client/server/server.go index 168b297c6..dd842d099 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -63,9 +63,6 @@ type Server struct { mutex sync.Mutex config *profilemanager.Config proto.UnimplementedDaemonServiceServer - clientRunning bool // protected by mutex - clientRunningChan chan struct{} - clientGiveUpChan chan struct{} connectClient *internal.ConnectClient @@ -104,11 +101,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() - - if s.clientRunning { - return nil - } - state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -178,10 +170,8 @@ func (s *Server) Start() error { return nil } - s.clientRunning = true - s.clientRunningChan = make(chan struct{}) - s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) + return nil } @@ -212,22 +202,12 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { - defer func() { - s.mutex.Lock() - s.clientRunning = false - s.mutex.Unlock() - }() - - if s.config.DisableAutoConnect { - if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { - log.Debugf("run client connection exited with error: %v", err) - } - log.Tracef("client connection exited") - return - } - +func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, + runningChan chan struct{}, +) { backOff := getConnectWithBackoff(ctx) + retryStarted := false + go func() { t := time.NewTicker(24 * time.Hour) for { @@ -236,36 +216,89 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil t.Stop() return case <-t.C: - mgmtState := statusRecorder.GetManagementState() - signalState := statusRecorder.GetSignalState() - if mgmtState.Connected && signalState.Connected { - log.Tracef("resetting status") - backOff.Reset() - } else { - log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) + if retryStarted { + + mgmtState := statusRecorder.GetManagementState() + signalState := statusRecorder.GetSignalState() + if mgmtState.Connected && signalState.Connected { + log.Tracef("resetting status") + retryStarted = false + } else { + log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) + } } } } }() runOperation := func() error { - err := s.connect(ctx, profileConfig, statusRecorder, runningChan) + log.Tracef("running client connection") + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, s.logFile) + s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) + + err := s.connectClient.Run(runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) - return err } - log.Tracef("client connection exited gracefully, do not need to retry") - return nil + if config.DisableAutoConnect { + return backoff.Permanent(err) + } + + if !retryStarted { + retryStarted = true + backOff.Reset() + } + + log.Tracef("client connection exited") + return fmt.Errorf("client connection exited") } - if err := backoff.Retry(runOperation, backOff); err != nil { - log.Errorf("operation failed: %v", err) + err := backoff.Retry(runOperation, backOff) + if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { + log.Errorf("received an error when trying to connect: %v", err) + } else { + log.Tracef("retry canceled") + } +} + +// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries +func getConnectWithBackoff(ctx context.Context) backoff.BackOff { + initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) + maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) + maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) + multiplier := defaultRetryMultiplier + + if envValue := os.Getenv(retryMultiplierVar); envValue != "" { + // parse the multiplier from the environment variable string value to float64 + value, err := strconv.ParseFloat(envValue, 64) + if err != nil { + log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) + } else { + multiplier = value + } } - if giveUpChan != nil { - close(giveUpChan) + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: initialInterval, + RandomizationFactor: 1, + Multiplier: multiplier, + MaxInterval: maxInterval, + MaxElapsedTime: maxElapsedTime, // 14 days + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// parseEnvDuration parses the environment variable and returns the duration +func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { + if envValue := os.Getenv(envVar); envValue != "" { + if duration, err := time.ParseDuration(envValue); err == nil { + return duration + } + log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) } + return defaultDuration } // loginAttempt attempts to login using the provided information. it returns a status in case something fails @@ -365,11 +398,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.LazyConnectionEnabled = msg.LazyConnectionEnabled config.BlockInbound = msg.BlockInbound - if msg.Mtu != nil { - mtu := uint16(*msg.Mtu) - config.MTU = &mtu - } - if _, err := profilemanager.UpdateConfig(config); err != nil { log.Errorf("failed to update profile config: %v", err) return nil, fmt.Errorf("failed to update profile config: %w", err) @@ -384,7 +412,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if s.actCancel != nil { s.actCancel() } - ctx, cancel := context.WithCancel(callerCtx) + ctx, cancel := context.WithCancel(s.rootCtx) md, ok := metadata.FromIncomingContext(callerCtx) if ok { @@ -394,11 +422,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() - if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil { + if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } - state := internal.CtxGetState(s.rootCtx) + state := internal.CtxGetState(ctx) defer func() { status, err := state.Status() if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) { @@ -454,7 +482,6 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro // nolint ctx = context.WithValue(ctx, system.DeviceNameCtxKey, msg.Hostname) } - s.mutex.Unlock() config, err := s.getConfig(activeProf) @@ -611,20 +638,6 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin // Up starts engine work in the daemon. func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { s.mutex.Lock() - if s.clientRunning { - state := internal.CtxGetState(s.rootCtx) - status, err := state.Status() - if err != nil { - s.mutex.Unlock() - return nil, err - } - if status == internal.StatusNeedsLogin { - s.actCancel() - } - s.mutex.Unlock() - - return s.waitForUp(callerCtx) - } defer s.mutex.Unlock() if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { @@ -640,16 +653,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR if err != nil { return nil, err } - if status != internal.StatusIdle { return nil, fmt.Errorf("up already in progress: current status %s", status) } - // it should be nil here, but in case it isn't we cancel it. + // it should be nil here, but . if s.actCancel != nil { s.actCancel() } ctx, cancel := context.WithCancel(s.rootCtx) + md, ok := metadata.FromIncomingContext(callerCtx) if ok { ctx = metadata.NewOutgoingContext(ctx, md) @@ -692,31 +705,23 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) - s.clientRunning = true - s.clientRunningChan = make(chan struct{}) - s.clientGiveUpChan = make(chan struct{}) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) - - return s.waitForUp(callerCtx) -} - -// todo: handle potential race conditions -func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) { timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - select { - case <-s.clientGiveUpChan: - return nil, fmt.Errorf("client gave up to connect") - case <-s.clientRunningChan: - s.isSessionActive.Store(true) - return &proto.UpResponse{}, nil - case <-callerCtx.Done(): - log.Debug("context done, stopping the wait for engine to become ready") - return nil, callerCtx.Err() - case <-timeoutCtx.Done(): - log.Debug("up is timed out, stopping the wait for engine to become ready") - return nil, timeoutCtx.Err() + runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) + for { + select { + case <-runningChan: + s.isSessionActive.Store(true) + return &proto.UpResponse{}, nil + case <-callerCtx.Done(): + log.Debug("context done, stopping the wait for engine to become ready") + return nil, callerCtx.Err() + case <-timeoutCtx.Done(): + log.Debug("up is timed out, stopping the wait for engine to become ready") + return nil, timeoutCtx.Err() + } } } @@ -990,47 +995,13 @@ func (s *Server) Status( ctx context.Context, msg *proto.StatusRequest, ) (*proto.StatusResponse, error) { - s.mutex.Lock() - clientRunning := s.clientRunning - s.mutex.Unlock() - - if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning { - state := internal.CtxGetState(s.rootCtx) - status, err := state.Status() - if err != nil { - return nil, err - } - - if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { - s.actCancel() - } - - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - loop: - for { - select { - case <-s.clientGiveUpChan: - ticker.Stop() - break loop - case <-s.clientRunningChan: - ticker.Stop() - break loop - case <-ticker.C: - status, err := state.Status() - if err != nil { - continue - } - if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { - s.actCancel() - } - continue - case <-ctx.Done(): - return nil, ctx.Err() - } - } + if ctx.Err() != nil { + return nil, ctx.Err() } + s.mutex.Lock() + defer s.mutex.Unlock() + status, err := internal.CtxGetState(s.rootCtx).Status() if err != nil { return nil, err @@ -1132,7 +1103,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p AdminURL: adminURL.String(), InterfaceName: cfg.WgIface, WireguardPort: int64(cfg.WgPort), - Mtu: int64(cfg.MTU), DisableAutoConnect: cfg.DisableAutoConnect, ServerSSHAllowed: *cfg.ServerSSHAllowed, RosenpassEnabled: cfg.RosenpassEnabled, @@ -1148,6 +1118,45 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p }, nil } +func (s *Server) onSessionExpire() { + if runtime.GOOS != "windows" { + isUIActive := internal.CheckUIApp() + if !isUIActive && s.config.DisableNotifications != nil && !*s.config.DisableNotifications { + if err := sendTerminalNotification(); err != nil { + log.Errorf("send session expire terminal notification: %v", err) + } + } + } +} + +// sendTerminalNotification sends a terminal notification message +// to inform the user that the NetBird connection session has expired. +func sendTerminalNotification() error { + message := "NetBird connection session expired\n\nPlease re-authenticate to connect to the network." + echoCmd := exec.Command("echo", message) + wallCmd := exec.Command("sudo", "wall") + + echoCmdStdout, err := echoCmd.StdoutPipe() + if err != nil { + return err + } + wallCmd.Stdin = echoCmdStdout + + if err := echoCmd.Start(); err != nil { + return err + } + + if err := wallCmd.Start(); err != nil { + return err + } + + if err := echoCmd.Wait(); err != nil { + return err + } + + return wallCmd.Wait() +} + // AddProfile adds a new profile to the daemon. func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { s.mutex.Lock() @@ -1248,16 +1257,6 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) return features, nil } -func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { - log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - if err := s.connectClient.Run(runningChan); err != nil { - return err - } - return nil -} - func (s *Server) checkProfilesDisabled() bool { // Check if the environment variable is set to disable profiles if s.profilesDisabled { @@ -1275,168 +1274,3 @@ func (s *Server) checkUpdateSettingsDisabled() bool { return false } - -func (s *Server) onSessionExpire() { - if runtime.GOOS != "windows" { - isUIActive := internal.CheckUIApp() - if !isUIActive && s.config.DisableNotifications != nil && !*s.config.DisableNotifications { - if err := sendTerminalNotification(); err != nil { - log.Errorf("send session expire terminal notification: %v", err) - } - } - } -} - -// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries -func getConnectWithBackoff(ctx context.Context) backoff.BackOff { - initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) - maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) - maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) - multiplier := defaultRetryMultiplier - - if envValue := os.Getenv(retryMultiplierVar); envValue != "" { - // parse the multiplier from the environment variable string value to float64 - value, err := strconv.ParseFloat(envValue, 64) - if err != nil { - log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) - } else { - multiplier = value - } - } - - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: initialInterval, - RandomizationFactor: 1, - Multiplier: multiplier, - MaxInterval: maxInterval, - MaxElapsedTime: maxElapsedTime, // 14 days - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - -// parseEnvDuration parses the environment variable and returns the duration -func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { - if envValue := os.Getenv(envVar); envValue != "" { - if duration, err := time.ParseDuration(envValue); err == nil { - return duration - } - log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) - } - return defaultDuration -} - -func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { - pbFullStatus := proto.FullStatus{ - ManagementState: &proto.ManagementState{}, - SignalState: &proto.SignalState{}, - LocalPeerState: &proto.LocalPeerState{}, - Peers: []*proto.PeerState{}, - } - - pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL - pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected - if err := fullStatus.ManagementState.Error; err != nil { - pbFullStatus.ManagementState.Error = err.Error() - } - - pbFullStatus.SignalState.URL = fullStatus.SignalState.URL - pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected - if err := fullStatus.SignalState.Error; err != nil { - pbFullStatus.SignalState.Error = err.Error() - } - - pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP - pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey - pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface - pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN - pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive - pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled - pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) - pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules) - pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled - - for _, peerState := range fullStatus.Peers { - pbPeerState := &proto.PeerState{ - IP: peerState.IP, - PubKey: peerState.PubKey, - ConnStatus: peerState.ConnStatus.String(), - ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), - Relayed: peerState.Relayed, - LocalIceCandidateType: peerState.LocalIceCandidateType, - RemoteIceCandidateType: peerState.RemoteIceCandidateType, - LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint, - RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint, - RelayAddress: peerState.RelayServerAddress, - Fqdn: peerState.FQDN, - LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake), - BytesRx: peerState.BytesRx, - BytesTx: peerState.BytesTx, - RosenpassEnabled: peerState.RosenpassEnabled, - Networks: maps.Keys(peerState.GetRoutes()), - Latency: durationpb.New(peerState.Latency), - } - pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) - } - - for _, relayState := range fullStatus.Relays { - pbRelayState := &proto.RelayState{ - URI: relayState.URI, - Available: relayState.Err == nil, - } - if err := relayState.Err; err != nil { - pbRelayState.Error = err.Error() - } - pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState) - } - - for _, dnsState := range fullStatus.NSGroupStates { - var err string - if dnsState.Error != nil { - err = dnsState.Error.Error() - } - - var servers []string - for _, server := range dnsState.Servers { - servers = append(servers, server.String()) - } - - pbDnsState := &proto.NSGroupState{ - Servers: servers, - Domains: dnsState.Domains, - Enabled: dnsState.Enabled, - Error: err, - } - pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState) - } - - return &pbFullStatus -} - -// sendTerminalNotification sends a terminal notification message -// to inform the user that the NetBird connection session has expired. -func sendTerminalNotification() error { - message := "NetBird connection session expired\n\nPlease re-authenticate to connect to the network." - echoCmd := exec.Command("echo", message) - wallCmd := exec.Command("sudo", "wall") - - echoCmdStdout, err := echoCmd.StdoutPipe() - if err != nil { - return err - } - wallCmd.Stdin = echoCmdStdout - - if err := echoCmd.Start(); err != nil { - return err - } - - if err := wallCmd.Start(); err != nil { - return err - } - - if err := echoCmd.Wait(); err != nil { - return err - } - - return wallCmd.Wait() -} diff --git a/client/server/server_test.go b/client/server/server_test.go index 09b0ed499..6f7c4a89a 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,25 +10,25 @@ import ( "time" "github.com/golang/mock/gomock" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -134,12 +134,8 @@ func TestServer_Up(t *testing.T) { profName := "default" - u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") - require.NoError(t, err) - ic := profilemanager.ConfigInput{ - ConfigPath: filepath.Join(tempDir, profName+".json"), - ManagementURL: u.String(), + ConfigPath: filepath.Join(tempDir, profName+".json"), } _, err = profilemanager.UpdateOrCreateConfig(ic) @@ -157,9 +153,16 @@ func TestServer_Up(t *testing.T) { } s := New(ctx, "console", "", false, false) + err = s.Start() require.NoError(t, err) + u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") + require.NoError(t, err) + s.config = &profilemanager.Config{ + ManagementURL: u, + } + upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -168,7 +171,6 @@ func TestServer_Up(t *testing.T) { Username: &currUser.Username, } _, err = s.Up(upCtx, upReq) - log.Errorf("error from Up: %v", err) assert.Contains(t, err.Error(), "context deadline exceeded") } @@ -293,20 +295,15 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - permissionsManagerMock := permissions.NewMockManager(ctrl) - peersManager := peers.NewManager(store, permissionsManagerMock) - settingsManagerMock := settings.NewMockManager(ctrl) - - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + permissionsManagerMock := permissions.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) diff --git a/client/system/info.go b/client/system/info.go index a180be4c0..ea3f6063a 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -6,7 +6,6 @@ import ( "net/netip" "strings" - log "github.com/sirupsen/logrus" "google.golang.org/grpc/metadata" "github.com/netbirdio/netbird/shared/management/proto" @@ -96,6 +95,14 @@ func (i *Info) SetFlags( i.LazyConnectionEnabled = lazyConnectionEnabled } +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment +} + // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) @@ -173,7 +180,6 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { // GetInfoWithChecks retrieves and parses the system information with applied checks. func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { - log.Debugf("gathering system information with checks: %d", len(checks)) processCheckPaths := make([]string, 0) for _, check := range checks { processCheckPaths = append(processCheckPaths, check.GetFiles()...) @@ -183,11 +189,16 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro if err != nil { return nil, err } - log.Debugf("gathering process check information completed") info := GetInfo(ctx) info.Files = files - log.Debugf("all system information gathered successfully") return info, nil } + +// UpdateStaticInfo asynchronously updates static system and platform information +func UpdateStaticInfo() { + go func() { + _ = updateStaticInfo() + }() +} diff --git a/client/system/info_android.go b/client/system/info_android.go index 78895bfa8..56fe0741d 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -15,11 +15,6 @@ import ( "github.com/netbirdio/netbird/version" ) -// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update -func UpdateStaticInfoAsync() { - // do nothing -} - // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { kernel := "android" diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index caa344737..f105ada60 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -19,10 +19,6 @@ import ( "github.com/netbirdio/netbird/version" ) -func UpdateStaticInfoAsync() { - go updateStaticInfo() -} - // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { utsname := unix.Utsname{} @@ -45,7 +41,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := getStaticInfo() + si := updateStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go index 8e1353151..bed6711de 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -18,11 +18,6 @@ import ( "github.com/netbirdio/netbird/version" ) -// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update -func UpdateStaticInfoAsync() { - // do nothing -} - // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { out := _getInfo() diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 705c37920..897ec0a35 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -10,11 +10,6 @@ import ( "github.com/netbirdio/netbird/version" ) -// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update -func UpdateStaticInfoAsync() { - // do nothing -} - // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { diff --git a/client/system/info_linux.go b/client/system/info_linux.go index 6c7a23b95..9bfc82009 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -23,10 +23,6 @@ var ( getSystemInfo = defaultSysInfoImplementation ) -func UpdateStaticInfoAsync() { - go updateStaticInfo() -} - // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { info := _getInfo() @@ -52,7 +48,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := getStaticInfo() + si := updateStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_windows.go b/client/system/info_windows.go index d7f8f30aa..6f05ded20 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -2,51 +2,187 @@ package system import ( "context" + "fmt" "os" "runtime" + "strings" "time" log "github.com/sirupsen/logrus" + "github.com/yusufpapurcu/wmi" + "golang.org/x/sys/windows/registry" "github.com/netbirdio/netbird/version" ) -func UpdateStaticInfoAsync() { - go updateStaticInfo() +type Win32_OperatingSystem struct { + Caption string +} + +type Win32_ComputerSystem struct { + Manufacturer string +} + +type Win32_ComputerSystemProduct struct { + Name string +} + +type Win32_BIOS struct { + SerialNumber string } // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { + osName, osVersion := getOSNameAndVersion() + buildVersion := getBuildVersion() + + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + start := time.Now() - si := getStaticInfo() + si := updateStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ Kernel: "windows", - OSVersion: si.OSVersion, + OSVersion: osVersion, Platform: "unknown", - OS: si.OSName, + OS: osName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), - KernelVersion: si.BuildVersion, + KernelVersion: buildVersion, + NetworkAddresses: addrs, SystemSerialNumber: si.SystemSerialNumber, SystemProductName: si.SystemProductName, SystemManufacturer: si.SystemManufacturer, Environment: si.Environment, } - addrs, err := networkAddresses() - if err != nil { - log.Warnf("failed to discover network addresses: %s", err) - } else { - gio.NetworkAddresses = addrs - } - systemHostname, _ := os.Hostname() gio.Hostname = extractDeviceName(ctx, systemHostname) gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) + return gio } + +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + +func getOSNameAndVersion() (string, string) { + var dst []Win32_OperatingSystem + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + log.Error(err) + return "Windows", getBuildVersion() + } + + if len(dst) == 0 { + return "Windows", getBuildVersion() + } + + split := strings.Split(dst[0].Caption, " ") + + if len(split) <= 3 { + return "Windows", getBuildVersion() + } + + name := split[1] + version := split[2] + if split[2] == "Server" { + name = fmt.Sprintf("%s %s", split[1], split[2]) + version = split[3] + } + + return name, version +} + +func getBuildVersion() string { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) + if err != nil { + log.Error(err) + return "0.0.0.0" + } + defer func() { + deferErr := k.Close() + if deferErr != nil { + log.Error(deferErr) + } + }() + + major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") + if err != nil { + log.Error(err) + } + minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") + if err != nil { + log.Error(err) + } + build, _, err := k.GetStringValue("CurrentBuildNumber") + if err != nil { + log.Error(err) + } + // Update Build Revision + ubr, _, err := k.GetIntegerValue("UBR") + if err != nil { + log.Error(err) + } + ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) + return ver +} + +func sysNumber() (string, error) { + var dst []Win32_BIOS + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].SerialNumber, nil +} + +func sysProductName() (string, error) { + var dst []Win32_ComputerSystemProduct + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + // `ComputerSystemProduct` could be empty on some virtualized systems + if len(dst) < 1 { + return "unknown", nil + } + return dst[0].Name, nil +} + +func sysManufacturer() (string, error) { + var dst []Win32_ComputerSystem + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].Manufacturer, nil +} diff --git a/client/system/static_info.go b/client/system/static_info.go index 12a2663a1..f178ec932 100644 --- a/client/system/static_info.go +++ b/client/system/static_info.go @@ -3,7 +3,12 @@ package system import ( + "context" "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" ) var ( @@ -11,26 +16,25 @@ var ( once sync.Once ) -// StaticInfo is an object that contains machine information that does not change -type StaticInfo struct { - SystemSerialNumber string - SystemProductName string - SystemManufacturer string - Environment Environment - - // Windows specific fields - OSName string - OSVersion string - BuildVersion string -} - -func updateStaticInfo() { +func updateStaticInfo() StaticInfo { once.Do(func() { - staticInfo = newStaticInfo() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + staticInfo.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() }) -} - -func getStaticInfo() StaticInfo { - updateStaticInfo() return staticInfo } diff --git a/client/system/static_info_stub.go b/client/system/static_info_stub.go new file mode 100644 index 000000000..faa3e700b --- /dev/null +++ b/client/system/static_info_stub.go @@ -0,0 +1,8 @@ +//go:build android || freebsd || ios + +package system + +// updateStaticInfo returns an empty implementation for unsupported platforms +func updateStaticInfo() StaticInfo { + return StaticInfo{} +} diff --git a/client/system/static_info_update.go b/client/system/static_info_update.go deleted file mode 100644 index af8b1e266..000000000 --- a/client/system/static_info_update.go +++ /dev/null @@ -1,35 +0,0 @@ -//go:build (linux && !android) || (darwin && !ios) - -package system - -import ( - "context" - "sync" - "time" - - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" -) - -func newStaticInfo() StaticInfo { - si := StaticInfo{} - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - wg := sync.WaitGroup{} - wg.Add(3) - go func() { - si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo() - wg.Done() - }() - go func() { - si.Environment.Cloud = detect_cloud.Detect(ctx) - wg.Done() - }() - go func() { - si.Environment.Platform = detect_platform.Detect(ctx) - wg.Done() - }() - wg.Wait() - return si -} diff --git a/client/system/static_info_update_windows.go b/client/system/static_info_update_windows.go deleted file mode 100644 index 5f232c1de..000000000 --- a/client/system/static_info_update_windows.go +++ /dev/null @@ -1,184 +0,0 @@ -package system - -import ( - "context" - "fmt" - "strings" - "sync" - "time" - - log "github.com/sirupsen/logrus" - "github.com/yusufpapurcu/wmi" - "golang.org/x/sys/windows/registry" - - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" -) - -type Win32_OperatingSystem struct { - Caption string -} - -type Win32_ComputerSystem struct { - Manufacturer string -} - -type Win32_ComputerSystemProduct struct { - Name string -} - -type Win32_BIOS struct { - SerialNumber string -} - -func newStaticInfo() StaticInfo { - si := StaticInfo{} - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo() - wg.Done() - }() - wg.Add(1) - go func() { - si.Environment.Cloud = detect_cloud.Detect(ctx) - wg.Done() - }() - wg.Add(1) - go func() { - si.Environment.Platform = detect_platform.Detect(ctx) - wg.Done() - }() - wg.Add(1) - go func() { - si.OSName, si.OSVersion = getOSNameAndVersion() - wg.Done() - }() - wg.Add(1) - go func() { - si.BuildVersion = getBuildVersion() - wg.Done() - }() - wg.Wait() - return si -} - -func sysInfo() (serialNumber string, productName string, manufacturer string) { - var err error - serialNumber, err = sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - productName, err = sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err = sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - - return serialNumber, productName, manufacturer -} - -func sysNumber() (string, error) { - var dst []Win32_BIOS - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - return dst[0].SerialNumber, nil -} - -func sysProductName() (string, error) { - var dst []Win32_ComputerSystemProduct - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - // `ComputerSystemProduct` could be empty on some virtualized systems - if len(dst) < 1 { - return "unknown", nil - } - return dst[0].Name, nil -} - -func sysManufacturer() (string, error) { - var dst []Win32_ComputerSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - return dst[0].Manufacturer, nil -} - -func getOSNameAndVersion() (string, string) { - var dst []Win32_OperatingSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - log.Error(err) - return "Windows", getBuildVersion() - } - - if len(dst) == 0 { - return "Windows", getBuildVersion() - } - - split := strings.Split(dst[0].Caption, " ") - - if len(split) <= 3 { - return "Windows", getBuildVersion() - } - - name := split[1] - version := split[2] - if split[2] == "Server" { - name = fmt.Sprintf("%s %s", split[1], split[2]) - version = split[3] - } - - return name, version -} - -func getBuildVersion() string { - k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) - if err != nil { - log.Error(err) - return "0.0.0.0" - } - defer func() { - deferErr := k.Close() - if deferErr != nil { - log.Error(deferErr) - } - }() - - major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") - if err != nil { - log.Error(err) - } - minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") - if err != nil { - log.Error(err) - } - build, _, err := k.GetStringValue("CurrentBuildNumber") - if err != nil { - log.Error(err) - } - // Update Build Revision - ubr, _, err := k.GetIntegerValue("UBR") - if err != nil { - log.Error(err) - } - ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) - return ver -} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 25d7380a9..f43606de1 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -257,7 +257,6 @@ type serviceClient struct { iPreSharedKey *widget.Entry iInterfaceName *widget.Entry iInterfacePort *widget.Entry - iMTU *widget.Entry // switch elements for settings form sRosenpassPermissive *widget.Check @@ -273,7 +272,6 @@ type serviceClient struct { RosenpassPermissive bool interfaceName string interfacePort int - mtu uint16 networkMonitor bool disableDNS bool disableClientRoutes bool @@ -415,7 +413,6 @@ func (s *serviceClient) showSettingsUI() { s.iPreSharedKey = widget.NewPasswordEntry() s.iInterfaceName = widget.NewEntry() s.iInterfacePort = widget.NewEntry() - s.iMTU = widget.NewEntry() s.sRosenpassPermissive = widget.NewCheck("Enable Rosenpass permissive mode", nil) @@ -449,7 +446,6 @@ func (s *serviceClient) getSettingsForm() *widget.Form { {Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive}, {Text: "Interface Name", Widget: s.iInterfaceName}, {Text: "Interface Port", Widget: s.iInterfacePort}, - {Text: "MTU", Widget: s.iMTU}, {Text: "Management URL", Widget: s.iMngURL}, {Text: "Pre-shared Key", Widget: s.iPreSharedKey}, {Text: "Log File", Widget: s.iLogFile}, @@ -486,21 +482,6 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } - var mtu int64 - mtuText := strings.TrimSpace(s.iMTU.Text) - if mtuText != "" { - var err error - mtu, err = strconv.ParseInt(mtuText, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings) - return - } - if mtu < iface.MinMTU || mtu > iface.MaxMTU { - dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings) - return - } - } - iMngURL := strings.TrimSpace(s.iMngURL.Text) defer s.wSettings.Close() @@ -509,7 +490,6 @@ func (s *serviceClient) getSettingsForm() *widget.Form { if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || s.RosenpassPermissive != s.sRosenpassPermissive.Checked || s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) || - s.mtu != uint16(mtu) || s.networkMonitor != s.sNetworkMonitor.Checked || s.disableDNS != s.sDisableDNS.Checked || s.disableClientRoutes != s.sDisableClientRoutes.Checked || @@ -518,7 +498,6 @@ func (s *serviceClient) getSettingsForm() *widget.Form { s.managementURL = iMngURL s.preSharedKey = s.iPreSharedKey.Text - s.mtu = uint16(mtu) currUser, err := user.Current() if err != nil { @@ -537,9 +516,6 @@ func (s *serviceClient) getSettingsForm() *widget.Form { req.RosenpassPermissive = &s.sRosenpassPermissive.Checked req.InterfaceName = &s.iInterfaceName.Text req.WireguardPort = &port - if mtu > 0 { - req.Mtu = &mtu - } req.NetworkMonitor = &s.sNetworkMonitor.Checked req.DisableDns = &s.sDisableDNS.Checked req.DisableClientRoutes = &s.sDisableClientRoutes.Checked @@ -563,28 +539,27 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } - go func() { - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + log.Errorf("get service status: %v", err) + dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) + return + } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) + log.Errorf("down service: %v", err) + } + + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) return } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) - if err != nil { - log.Errorf("down service: %v", err) - } + } - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) - return - } - } - }() } }, OnCancel: func() { @@ -1113,7 +1088,6 @@ func (s *serviceClient) getSrvConfig() { s.RosenpassPermissive = cfg.RosenpassPermissive s.interfaceName = cfg.WgIface s.interfacePort = cfg.WgPort - s.mtu = cfg.MTU s.networkMonitor = *cfg.NetworkMonitor s.disableDNS = cfg.DisableDNS @@ -1126,12 +1100,6 @@ func (s *serviceClient) getSrvConfig() { s.iPreSharedKey.SetText(cfg.PreSharedKey) s.iInterfaceName.SetText(cfg.WgIface) s.iInterfacePort.SetText(strconv.Itoa(cfg.WgPort)) - if cfg.MTU != 0 { - s.iMTU.SetText(strconv.Itoa(int(cfg.MTU))) - } else { - s.iMTU.SetText("") - s.iMTU.SetPlaceHolder(strconv.Itoa(int(iface.DefaultMTU))) - } s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive) if !cfg.RosenpassEnabled { s.sRosenpassPermissive.Disable() @@ -1192,12 +1160,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config { config.WgPort = iface.DefaultWgPort } - if cfg.Mtu != 0 { - config.MTU = uint16(cfg.Mtu) - } else { - config.MTU = iface.DefaultMTU - } - config.DisableAutoConnect = cfg.DisableAutoConnect config.ServerSSHAllowed = &cfg.ServerSSHAllowed config.RosenpassEnabled = cfg.RosenpassEnabled diff --git a/flow/client/client.go b/flow/client/client.go index 603fd6882..949824065 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -20,9 +20,9 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" - nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) type GRPCClient struct { diff --git a/go.mod b/go.mod index c880ace3e..771922063 100644 --- a/go.mod +++ b/go.mod @@ -6,24 +6,26 @@ require ( cunicu.li/go-rosenpass v0.4.0 github.com/cenkalti/backoff/v4 v4.3.0 github.com/cloudflare/circl v1.3.3 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.0 github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.6 + github.com/pion/ice/v3 v3.0.2 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.40.0 - golang.org/x/sys v0.34.0 + golang.org/x/crypto v0.37.0 + golang.org/x/sys v0.32.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.73.0 - google.golang.org/protobuf v1.36.8 + google.golang.org/grpc v1.64.1 + google.golang.org/protobuf v1.36.6 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -46,7 +48,6 @@ require ( github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.8 github.com/godbus/dbus/v5 v5.1.0 - github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/gopacket v1.1.19 @@ -62,19 +63,17 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 - github.com/pion/ice/v4 v4.0.0-00010101000000-000000000000 - github.com/pion/logging v0.2.4 + github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 - github.com/pion/stun/v3 v3.0.0 - github.com/pion/transport/v3 v3.0.7 + github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 github.com/quic-go/quic-go v0.48.2 @@ -95,18 +94,18 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 - go.opentelemetry.io/otel v1.35.0 + go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 - go.opentelemetry.io/otel/metric v1.35.0 - go.opentelemetry.io/otel/sdk/metric v1.35.0 + go.opentelemetry.io/otel/metric v1.26.0 + go.opentelemetry.io/otel/sdk/metric v1.26.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.42.0 - golang.org/x/oauth2 v0.28.0 - golang.org/x/sync v0.16.0 - golang.org/x/term v0.33.0 + golang.org/x/net v0.39.0 + golang.org/x/oauth2 v0.27.0 + golang.org/x/sync v0.13.0 + golang.org/x/term v0.31.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -119,7 +118,7 @@ require ( require ( cloud.google.com/go/auth v0.3.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.6.0 // indirect + cloud.google.com/go/compute/metadata v0.3.0 // indirect dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect @@ -215,10 +214,8 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect - github.com/pion/dtls/v3 v3.0.7 // indirect - github.com/pion/mdns/v2 v2.0.7 // indirect + github.com/pion/mdns v0.0.12 // indirect github.com/pion/transport/v2 v2.2.4 // indirect - github.com/pion/turn/v4 v4.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect @@ -235,23 +232,22 @@ require ( github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - github.com/wlynxg/anet v0.0.3 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect - go.opentelemetry.io/otel/sdk v1.35.0 // indirect - go.opentelemetry.io/otel/trace v1.35.0 // indirect + go.opentelemetry.io/otel/sdk v1.26.0 // indirect + go.opentelemetry.io/otel/trace v1.26.0 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect - golang.org/x/mod v0.25.0 // indirect - golang.org/x/text v0.27.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.34.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) @@ -264,6 +260,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 -replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 +replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index 1b6cdd0a9..b70a6b84c 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= -cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= @@ -250,8 +250,8 @@ github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= -github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -506,10 +506,10 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= -github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= -github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= +github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e h1:S85laGfx1UP+nmRF9smP6/TY965kLWz41PbBK1TX8g0= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e/go.mod h1:Jjve0+eUjOLKL3PJtAhjfM2iJ0SxWio5elHqlV1ymP8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -553,29 +553,21 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= -github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= +github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= -github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= -github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= -github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= +github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= +github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= -github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= -github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= +github.com/pion/transport/v3 v3.0.1 h1:gDTlPJwROfSfz6QfSi0ZmeCSkFcnWWiiR9ES0ouANiM= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= -github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= -github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= -github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= -github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -603,8 +595,8 @@ github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0 github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= @@ -697,8 +689,6 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= -github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= -github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -730,28 +720,26 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= -go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= -go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= +go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= +go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= -go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= -go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= -go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= -go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= -go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= -go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= -go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= +go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= +go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= +go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= +go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= +go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= +go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= +go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= +go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -779,8 +767,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -826,8 +814,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -873,8 +861,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -888,8 +876,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= -golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= +golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -903,8 +891,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -972,8 +960,8 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -981,8 +969,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= -golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -996,8 +984,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1060,8 +1048,8 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.8-0.20211022200916-316ba0b74098/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= -golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1144,11 +1132,10 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= -google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM= -google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= +google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1169,8 +1156,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= -google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= +google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= +google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1185,10 +1172,11 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index cfec1000e..2d7c65cbe 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -328,45 +328,6 @@ delete_auto_service_user() { echo "$PARSED_RESPONSE" } -delete_default_zitadel_admin() { - INSTANCE_URL=$1 - PAT=$2 - - # Search for the default zitadel-admin user - RESPONSE=$( - curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \ - -H "Authorization: Bearer $PAT" \ - -H "Content-Type: application/json" \ - -d '{ - "queries": [ - { - "userNameQuery": { - "userName": "zitadel-admin@", - "method": "TEXT_QUERY_METHOD_STARTS_WITH" - } - } - ] - }' - ) - - DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty') - - if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then - echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID" - - RESPONSE=$( - curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \ - -H "Authorization: Bearer $PAT" \ - -H "Content-Type: application/json" \ - ) - PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"') - handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE" - - else - echo "Default zitadel-admin user not found: $RESPONSE" - fi -} - init_zitadel() { echo -e "\nInitializing Zitadel with NetBird's applications\n" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" @@ -385,9 +346,6 @@ init_zitadel() { echo -n "Waiting for Zitadel to become ready " wait_api "$INSTANCE_URL" "$PAT" - echo "Deleting default zitadel-admin user..." - delete_default_zitadel_admin "$INSTANCE_URL" "$PAT" - # create the zitadel project echo "Creating new zitadel project" PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") diff --git a/infrastructure_files/nginx.tmpl.conf b/infrastructure_files/nginx.tmpl.conf index f7fa4a9d0..23fd760aa 100644 --- a/infrastructure_files/nginx.tmpl.conf +++ b/infrastructure_files/nginx.tmpl.conf @@ -17,7 +17,7 @@ upstream signal { server 127.0.0.1:10000; } upstream management { - # insert the grpc+http port of your management container here + # insert the grpc+http port of your signal container here server 127.0.0.1:8012; } @@ -75,4 +75,4 @@ server { ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem; ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem; -} +} \ No newline at end of file diff --git a/management/README.md b/management/README.md index c70285d43..1122a9e76 100644 --- a/management/README.md +++ b/management/README.md @@ -111,6 +111,3 @@ Generate gRpc code: #!/bin/bash protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=. ``` - - - diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index af860920f..071247938 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -26,11 +26,7 @@ func (s *BaseServer) JobManager() *server.JobManager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator( - context.Background(), - s.PeersManager(), - s.SettingsManager(), - s.EventStore()) + integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 016706b3b..77f899aa4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -105,8 +105,6 @@ type DefaultAccountManager struct { accountUpdateLocks sync.Map updateAccountPeersBufferInterval atomic.Int64 - loginFilter *loginFilter - disableDefaultPolicy bool } @@ -216,7 +214,6 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, - loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, } @@ -303,6 +300,9 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // User that performs the update has to belong to the account. // Returns an updated Settings func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) if err != nil { return nil, fmt.Errorf("failed to validate user permissions: %w", err) @@ -348,17 +348,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } } - if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil { - return err - } - if updateAccountPeers || groupsUpdated { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } } - return nil + return transaction.SaveAccountSettings(ctx, accountID, newSettings) }) if err != nil { return nil, err @@ -502,6 +498,8 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID) //nolint ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource)) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { @@ -537,6 +535,9 @@ func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) @@ -677,6 +678,8 @@ func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheSto // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err @@ -1045,6 +1048,9 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlockAccount() + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) @@ -1137,20 +1143,12 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai } func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID - - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID) - if err != nil { - return "", err - } - - if settings != nil && settings.Extra != nil && settings.Extra.UserApprovalRequired { - newUser.Blocked = true - newUser.PendingApproval = true - } - - err = am.Store.SaveUser(ctx, newUser) + err := am.Store.SaveUser(ctx, newUser) if err != nil { return "", err } @@ -1160,11 +1158,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, return "", err } - if newUser.PendingApproval { - am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true}) - } else { - am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) - } + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) return domainAccountID, nil } @@ -1363,6 +1357,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId) + defer func() { + if unlockAccount != nil { + unlockAccount() + } + }() + var addNewGroups []string var removeOldGroups []string var hasChanges bool @@ -1425,6 +1426,8 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return fmt.Errorf("error incrementing network serial: %w", err) } } + unlockAccount() + unlockAccount = nil return nil }) @@ -1633,16 +1636,17 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { - return am.loginFilter.allowLogin(wgPubKey, metahash) -} - func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { start := time.Now() defer func() { log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start)) }() + accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) + defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() + peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) @@ -1653,18 +1657,22 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - metahash := metaHash(meta, realIP.String()) - am.loginFilter.addLogin(peerPubKey, metahash) - return peer, netMap, postureChecks, nil } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { + accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) + defer accountUnlock() + peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer peerUnlock() + err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } + return nil + } func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { @@ -1673,6 +1681,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } + unlock := am.Store.AcquireReadLockByUID(ctx, accountID) + defer unlock() + + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer unlockPeer() + _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) @@ -1717,9 +1731,7 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err) continue } - if peer.UserID != "" { - peers = append(peers, peer) - } + peers = append(peers, peer) } if len(peers) > 0 { err := am.expireAndUpdatePeers(ctx, accountID, peers) @@ -1815,9 +1827,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, - Extra: &types.ExtraSettings{ - UserApprovalRequired: true, - }, }, Onboarding: types.AccountOnboarding{ OnboardingFlowPending: true, @@ -1924,9 +1933,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, - Extra: &types.ExtraSettings{ - UserApprovalRequired: true, - }, }, } @@ -2112,6 +2118,9 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee } func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) if err != nil { return fmt.Errorf("validate user permissions: %w", err) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 198f08bb8..13154b98c 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -32,8 +32,6 @@ type Manager interface { DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) - RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) @@ -79,7 +77,7 @@ type Manager interface { DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) @@ -128,5 +126,4 @@ type Manager interface { CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) - AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index 14fb27f42..2770cfdb0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -15,7 +15,6 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/prometheus/client_golang/prometheus/push" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -26,7 +25,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -3048,14 +3046,19 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") + minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD - testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "sync", "syncAndMark") } - if msPerOp > maxExpected { - b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3118,14 +3121,19 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") + minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD - testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "existingPeer") } - if msPerOp > maxExpected { - b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3188,44 +3196,24 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") + minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD - testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp > maxExpected { - b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } } -func TestMain(m *testing.M) { - exitCode := m.Run() - - if exitCode == 0 && os.Getenv("CI") == "true" { - runID := os.Getenv("GITHUB_RUN_ID") - storeEngine := os.Getenv("NETBIRD_STORE_ENGINE") - err := push.New("http://localhost:9091", "account_manager_benchmark"). - Collector(testing_tools.BenchmarkDuration). - Grouping("ci_run", runID). - Grouping("store_engine", storeEngine). - Push() - if err != nil { - log.Printf("Failed to push metrics: %v", err) - } else { - time.Sleep(1 * time.Minute) - _ = push.New("http://localhost:9091", "account_manager_benchmark"). - Grouping("ci_run", runID). - Grouping("store_engine", storeEngine). - Delete() - } - } - - os.Exit(exitCode) -} - func Test_GetCreateAccountByPrivateDomain(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -3606,93 +3594,3 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { require.Error(t, err, "should fail with invalid peer ID") }) } - -func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create a domain-based account with user approval enabled - existingAccountID := "existing-account" - account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false) - account.Settings.Extra = &types.ExtraSettings{ - UserApprovalRequired: true, - } - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Set the account as domain primary account - account.IsDomainPrimaryAccount = true - account.DomainCategory = types.PrivateCategory - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Test adding new user to existing account with approval required - newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ - UserId: newUserID, - Domain: "example.com", - DomainCategory: types.PrivateCategory, - } - - acc, err := manager.Store.GetAccount(context.Background(), existingAccountID) - require.NoError(t, err) - require.True(t, acc.IsDomainPrimaryAccount, "Account should be primary for the domain") - require.Equal(t, "example.com", acc.Domain, "Account domain should match") - - returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) - require.NoError(t, err) - require.Equal(t, existingAccountID, returnedAccountID) - - // Verify user was created with pending approval - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) - require.NoError(t, err) - assert.True(t, user.Blocked, "User should be blocked when approval is required") - assert.True(t, user.PendingApproval, "User should be pending approval") - assert.Equal(t, existingAccountID, user.AccountID) -} - -func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create a domain-based account without user approval - ownerUserAuth := nbcontext.UserAuth{ - UserId: "owner-user", - Domain: "example.com", - DomainCategory: types.PrivateCategory, - } - existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth) - require.NoError(t, err) - - // Modify the account to disable user approval - account, err := manager.Store.GetAccount(context.Background(), existingAccountID) - require.NoError(t, err) - account.Settings.Extra = &types.ExtraSettings{ - UserApprovalRequired: false, - } - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Test adding new user to existing account without approval required - newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ - UserId: newUserID, - Domain: "example.com", - DomainCategory: types.PrivateCategory, - } - - returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) - require.NoError(t, err) - require.Equal(t, existingAccountID, returnedAccountID) - - // Verify user was created without pending approval - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) - require.NoError(t, err) - assert.False(t, user.Blocked, "User should not be blocked when approval is not required") - assert.False(t, user.PendingApproval, "User should not be pending approval") - assert.Equal(t, existingAccountID, user.AccountID) -} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index f97e35fa9..6b4cab04a 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -177,8 +177,6 @@ const ( AccountNetworkRangeUpdated Activity = 87 PeerIPUpdated Activity = 88 - UserApproved Activity = 89 - UserRejected Activity = 90 JobCreatedByUser Activity = 89 @@ -290,9 +288,6 @@ var activityMap = map[Activity]Code{ PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, JobCreatedByUser: {"Create Job for peer", "peer.job.create"}, - - UserApproved: {"User approved", "user.approve"}, - UserRejected: {"User rejected", "user.reject"}, } // StringCode returns a string code of the activity diff --git a/management/server/auth/jwt/extractor.go b/management/server/auth/jwt/extractor.go index d270d0ff1..fab429125 100644 --- a/management/server/auth/jwt/extractor.go +++ b/management/server/auth/jwt/extractor.go @@ -5,7 +5,7 @@ import ( "net/url" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" nbcontext "github.com/netbirdio/netbird/management/server/context" diff --git a/management/server/auth/jwt/validator.go b/management/server/auth/jwt/validator.go index 239447b96..5b38ca786 100644 --- a/management/server/auth/jwt/validator.go +++ b/management/server/auth/jwt/validator.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" ) @@ -63,10 +63,12 @@ type Validator struct { } var ( - errKeyNotFound = errors.New("unable to find appropriate key") - errTokenEmpty = errors.New("required authorization token not found") - errTokenInvalid = errors.New("token is invalid") - errTokenParsing = errors.New("token could not be parsed") + errKeyNotFound = errors.New("unable to find appropriate key") + errInvalidAudience = errors.New("invalid audience") + errInvalidIssuer = errors.New("invalid issuer") + errTokenEmpty = errors.New("required authorization token not found") + errTokenInvalid = errors.New("token is invalid") + errTokenParsing = errors.New("token could not be parsed") ) func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator { @@ -86,6 +88,24 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { + // Verify 'aud' claim + var checkAud bool + for _, audience := range v.audienceList { + checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false) + if checkAud { + break + } + } + if !checkAud { + return token, errInvalidAudience + } + + // Verify 'issuer' claim + checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false) + if !checkIss { + return token, errInvalidIssuer + } + // If keys are rotated, verify the keys prior to token validation if v.idpSignkeyRefreshEnabled { // If the keys are invalid, retrieve new ones @@ -124,7 +144,7 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { } // ValidateAndParse validates the token and returns the parsed token -func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { +func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... if token == "" { // If we get here, the required token is missing @@ -133,13 +153,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To } // Now parse the token - parsedToken, err := jwt.Parse( - token, - v.getKeyFunc(ctx), - jwt.WithAudience(v.audienceList...), - jwt.WithIssuer(v.issuer), - jwt.WithIssuedAt(), - ) + parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx)) // Check if there was an error in parsing... if err != nil { diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index ece9dc321..53d479c90 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -7,7 +7,7 @@ import ( "fmt" "hash/crc32" - "github.com/golang-jwt/jwt/v5" + "github.com/golang-jwt/jwt" "github.com/netbirdio/netbird/base62" nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go index 30a7a7161..bc7066548 100644 --- a/management/server/auth/manager_mock.go +++ b/management/server/auth/manager_mock.go @@ -3,7 +3,7 @@ package auth import ( "context" - "github.com/golang-jwt/jwt/v5" + "github.com/golang-jwt/jwt" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index c8015eb37..55fb1e31a 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/management/server/dns.go b/management/server/dns.go index 6b73dbd0e..12aa6e21c 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -20,9 +20,29 @@ import ( // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { + CustomZones sync.Map NameServerGroups sync.Map } +// GetCustomZone retrieves a cached custom zone +func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) { + if c == nil { + return nil, false + } + if value, ok := c.CustomZones.Load(key); ok { + return value.(*proto.CustomZone), true + } + return nil, false +} + +// SetCustomZone stores a custom zone in the cache +func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) { + if c == nil { + return + } + c.CustomZones.Store(key, value) +} + // GetNameServerGroup retrieves a cached name server group func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { if c == nil { @@ -93,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) - if err = transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave) }) if err != nil { return err @@ -192,8 +212,14 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC } for _, zone := range update.CustomZones { - protoZone := convertToProtoCustomZone(zone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) + cacheKey := zone.Domain + if cachedZone, exists := cache.GetCustomZone(cacheKey); exists { + protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone) + } else { + protoZone := convertToProtoCustomZone(zone) + cache.SetCustomZone(cacheKey, protoZone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) + } } for _, nsGroup := range update.NameServerGroups { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index a4be99acb..19b89f574 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -474,6 +474,15 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Results should be different for different inputs") } + // Verify that the cache contains elements from both configs + if _, exists := cache.GetCustomZone("example.com"); !exists { + t.Errorf("Cache should contain custom zone for example.com") + } + + if _, exists := cache.GetCustomZone("example.org"); !exists { + t.Errorf("Cache should contain custom zone for example.org") + } + if _, exists := cache.GetNameServerGroup("group1"); !exists { t.Errorf("Cache should contain name server group 'group1'") } diff --git a/management/server/group.go b/management/server/group.go index 487cb6d97..915a87086 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -67,6 +67,9 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, // CreateGroup object of the peers func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) if err != nil { return status.NewPermissionValidationError(err) @@ -93,6 +96,10 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use return err } + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + if err := transaction.CreateGroup(ctx, newGroup); err != nil { return status.Errorf(status.Internal, "failed to create group: %v", err) } @@ -102,8 +109,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err) } } - - return transaction.IncrementNetworkSerial(ctx, accountID) + return nil }) if err != nil { return err @@ -122,6 +128,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use // UpdateGroup object of the peers func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) if err != nil { return status.NewPermissionValidationError(err) @@ -167,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } - if err = transaction.UpdateGroup(ctx, newGroup); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.UpdateGroup(ctx, newGroup) }) if err != nil { return err @@ -202,45 +211,35 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } var eventsToStore []func() + var groupsToSave []*types.Group var updateAccountPeers bool - var globalErr error - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - - if err = transaction.CreateGroup(ctx, newGroup); err != nil { - return err - } - - err = transaction.IncrementNetworkSerial(ctx, accountID) - if err != nil { - return err - } - + groupsToSave = append(groupsToSave, newGroup) groupIDs = append(groupIDs, newGroup.ID) events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) - - return nil - }) - if err != nil { - log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) - if len(groupIDs) == 1 { - return err - } - globalErr = errors.Join(globalErr, err) - // continue updating other groups } - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + + return transaction.CreateGroups(ctx, accountID, groupsToSave) + }) if err != nil { return err } @@ -253,7 +252,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return globalErr + return nil } // UpdateGroups updates groups in the account. @@ -270,45 +269,35 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } var eventsToStore []func() + var groupsToSave []*types.Group var updateAccountPeers bool - var globalErr error - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - - if err = transaction.UpdateGroup(ctx, newGroup); err != nil { - return err - } - - err = transaction.IncrementNetworkSerial(ctx, accountID) - if err != nil { - return err - } + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) - - groupIDs = append(groupIDs, newGroup.ID) - - return nil - }) - if err != nil { - log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) - if len(groups) == 1 { - return err - } - globalErr = errors.Join(globalErr, err) - // continue updating other groups } - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + + return transaction.UpdateGroups(ctx, accountID, groupsToSave) + }) if err != nil { return err } @@ -321,7 +310,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return globalErr + return nil } // prepareGroupEvents prepares a list of event functions to be stored. @@ -393,6 +382,8 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } @@ -432,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } - if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete) }) if err != nil { return err @@ -451,6 +442,9 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var updateAccountPeers bool var err error @@ -460,11 +454,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID) }) if err != nil { return err @@ -479,6 +473,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupAddResource appends resource to the group func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *types.Group var updateAccountPeers bool var err error @@ -498,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID return err } - if err = transaction.UpdateGroup(ctx, group); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.UpdateGroup(ctx, group) }) if err != nil { return err @@ -517,6 +514,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var updateAccountPeers bool var err error @@ -526,11 +526,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.RemovePeerFromGroup(ctx, peerID, groupID) }) if err != nil { return err @@ -545,6 +545,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, // GroupDeleteResource removes resource from the group func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *types.Group var updateAccountPeers bool var err error @@ -564,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun return err } - if err = transaction.UpdateGroup(ctx, group); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.UpdateGroup(ctx, group) }) if err != nil { return err @@ -604,6 +607,13 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st newGroup.ID = xid.New().String() } + for _, peerID := range newGroup.Peers { + _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) + } + } + return nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index 31ff29cbc..1626a0464 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -648,7 +648,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, - newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.SkipAutoApply, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) require.NoError(t, err) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index ce0de5b9c..65e931f18 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -7,7 +7,6 @@ import ( "io" "net" "net/netip" - "os" "strings" "sync" "time" @@ -41,30 +40,21 @@ import ( internalStatus "github.com/netbirdio/netbird/shared/management/status" ) -const ( - envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" - envBlockPeers = "NB_BLOCK_SAME_PEERS" -) - // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { accountManager account.Manager settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - - peersUpdateManager *PeersUpdateManager - jobManager *JobManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager - peerLocks sync.Map - authManager auth.Manager - - logBlockedPeers bool - blockPeersWithSameConfig bool - integratedPeerValidator integrated_validator.IntegratedValidator + peersUpdateManager *PeersUpdateManager + jobManager *JobManager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager + peerLocks sync.Map + authManager auth.Manager + integratedPeerValidator integrated_validator.IntegratedValidator } // NewServer creates a new Management server @@ -95,24 +85,19 @@ func NewServer( } } - logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" - blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" - return &GRPCServer{ wgKey: key, // peerKey -> event channel - peersUpdateManager: peersUpdateManager, - jobManager: jobManager, - accountManager: accountManager, - settingsManager: settingsManager, - config: config, - secretsManager: secretsManager, - authManager: authManager, - appMetrics: appMetrics, - ephemeralManager: ephemeralManager, - logBlockedPeers: logBlockedPeers, - blockPeersWithSameConfig: blockPeersWithSameConfig, - integratedPeerValidator: integratedPeerValidator, + peersUpdateManager: peersUpdateManager, + jobManager: jobManager, + accountManager: accountManager, + settingsManager: settingsManager, + config: config, + secretsManager: secretsManager, + authManager: authManager, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, + integratedPeerValidator: integratedPeerValidator, }, nil } @@ -192,6 +177,9 @@ func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error { // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { reqStart := time.Now() + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequest() + } ctx := srv.Context() @@ -200,27 +188,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if err != nil { return err } - - realIP := getRealIP(ctx) - sRealIP := realIP.String() - peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) - metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() - } - if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) - } - if s.blockPeersWithSameConfig { - return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) - } - } - - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequest() - } - // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) @@ -244,12 +211,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) + realIP := getRealIP(ctx) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + if syncReq.GetMeta() == nil { log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) return mapError(ctx, err) @@ -267,7 +236,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } unlock() @@ -366,7 +335,6 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { - log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -505,9 +473,6 @@ func mapError(ctx context.Context, err error) error { default: } } - if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) { - return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error()) - } log.WithContext(ctx).Errorf("got an unhandled error: %s", err) return status.Errorf(codes.Internal, "failed handling request") } @@ -599,9 +564,16 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case of the successful registration login is also successful func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() + defer func() { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) + } + }() + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } realIP := getRealIP(ctx) - sRealIP := realIP.String() - log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) loginReq := &proto.LoginRequest{} peerKey, err := s.parseRequest(ctx, req, loginReq) @@ -609,24 +581,6 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, err } - peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) - metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { - if s.logBlockedPeers { - log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) - } - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() - } - if s.blockPeersWithSameConfig { - return nil, internalStatus.ErrPeerAlreadyLoggedIn - } - } - - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequest() - } - //nolint ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) @@ -637,12 +591,6 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - defer func() { - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) - } - }() - if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP) @@ -663,7 +611,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{ WireGuardPubKey: peerKey.String(), SSHKey: string(sshKey), - Meta: peerMeta, + Meta: extractPeerMeta(ctx, loginReq.GetMeta()), UserID: userID, SetupKey: loginReq.GetSetupKey(), ConnectionIP: realIP, @@ -1129,6 +1077,8 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (* return nil, mapError(ctx, err) } + s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) + log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start)) return &proto.Empty{}, nil diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index f1552d0ea..9f2afe29d 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -11,11 +11,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -198,7 +198,6 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.Extra != nil { settings.Extra = &types.ExtraSettings{ PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, - UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, @@ -328,7 +327,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A if settings.Extra != nil { apiSettings.Extra = &api.AccountExtraSettings{ PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, - UserApprovalRequired: settings.Extra.UserApprovalRequired, NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled, NetworkTrafficLogsGroups: settings.Extra.FlowGroups, NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled, diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 4b9b79fdc..1dad33a6f 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -15,11 +15,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/management/server/types" ) func initAccountsTestData(t *testing.T, account *types.Account) *handler { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 1fd3b7f9a..6c301aa72 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -488,33 +488,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn } return &api.PeerBatch{ - CreatedAt: peer.CreatedAt, - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.GetLastLogin(), - LoginExpired: peer.Status.LoginExpired, - AccessiblePeersCount: accessiblePeersCount, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, + CreatedAt: peer.CreatedAt, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.GetLastLogin(), + LoginExpired: peer.Status.LoginExpired, + AccessiblePeersCount: accessiblePeersCount, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, - Ephemeral: peer.Ephemeral, } } diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 7bb6f2372..7950db1e8 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -8,19 +8,17 @@ import ( "github.com/gorilla/mux" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/route" ) const failedToConvertRoute = "failed to convert route to response: %v" -const exitNodeCIDR = "0.0.0.0/0" - // handler is the routes handler of the account type handler struct { accountManager account.Manager @@ -126,16 +124,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { accessControlGroupIds = *req.AccessControlGroups } - // Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes) - skipAutoApply := false - if req.SkipAutoApply != nil { - skipAutoApply = *req.SkipAutoApply - } else if newPrefix.String() == exitNodeCIDR { - skipAutoApply = false - } - newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply) + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute) if err != nil { util.WriteError(r.Context(), err, w) @@ -152,31 +142,23 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { } func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { - return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId) -} - -func (h *handler) validateRouteUpdate(req api.PutApiRoutesRouteIdJSONRequestBody) error { - return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId) -} - -func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *string, peerGroups *[]string, networkId string) error { - if network != nil && domains != nil { + if req.Network != nil && req.Domains != nil { return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") } - if network == nil && domains == nil { + if req.Network == nil && req.Domains == nil { return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided") } - if peer == nil && peerGroups == nil { + if req.Peer == nil && req.PeerGroups == nil { return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided") } - if peer != nil && peerGroups != nil { + if req.Peer != nil && req.PeerGroups != nil { return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided") } - if utf8.RuneCountInString(networkId) > route.MaxNetIDChar || networkId == "" { + if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters", route.MaxNetIDChar) } @@ -213,7 +195,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { return } - if err := h.validateRouteUpdate(req); err != nil { + if err := h.validateRoute(req); err != nil { util.WriteError(r.Context(), err, w) return } @@ -223,24 +205,15 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { peerID = *req.Peer } - // Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes) - skipAutoApply := false - if req.SkipAutoApply != nil { - skipAutoApply = *req.SkipAutoApply - } else if req.Network != nil && *req.Network == exitNodeCIDR { - skipAutoApply = false - } - newRoute := &route.Route{ - ID: route.ID(routeID), - NetID: route.NetID(req.NetworkId), - Masquerade: req.Masquerade, - Metric: req.Metric, - Description: req.Description, - Enabled: req.Enabled, - Groups: req.Groups, - KeepRoute: req.KeepRoute, - SkipAutoApply: skipAutoApply, + ID: route.ID(routeID), + NetID: route.NetID(req.NetworkId), + Masquerade: req.Masquerade, + Metric: req.Metric, + Description: req.Description, + Enabled: req.Enabled, + Groups: req.Groups, + KeepRoute: req.KeepRoute, } if req.Domains != nil { @@ -348,19 +321,18 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { } network := serverRoute.Network.String() route := &api.Route{ - Id: string(serverRoute.ID), - Description: serverRoute.Description, - NetworkId: string(serverRoute.NetID), - Enabled: serverRoute.Enabled, - Peer: &serverRoute.Peer, - Network: &network, - Domains: &domains, - NetworkType: serverRoute.NetworkType.String(), - Masquerade: serverRoute.Masquerade, - Metric: serverRoute.Metric, - Groups: serverRoute.Groups, - KeepRoute: serverRoute.KeepRoute, - SkipAutoApply: &serverRoute.SkipAutoApply, + Id: string(serverRoute.ID), + Description: serverRoute.Description, + NetworkId: string(serverRoute.NetID), + Enabled: serverRoute.Enabled, + Peer: &serverRoute.Peer, + Network: &network, + Domains: &domains, + NetworkType: serverRoute.NetworkType.String(), + Masquerade: serverRoute.Masquerade, + Metric: serverRoute.Metric, + Groups: serverRoute.Groups, + KeepRoute: serverRoute.KeepRoute, } if len(serverRoute.PeerGroups) > 0 { diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 466a7987f..fc0e112f7 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -15,13 +15,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/shared/management/domain" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -62,22 +62,21 @@ func initRoutesTestData() *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { - switch routeID { - case existingRouteID: + if routeID == existingRouteID { return baseExistingRoute, nil - case existingRouteID2: + } + if routeID == existingRouteID2 { route := baseExistingRoute.Copy() route.PeerGroups = []string{existingGroupID} return route, nil - case existingRouteID3: + } else if routeID == existingRouteID3 { route := baseExistingRoute.Copy() route.Domains = domain.List{existingDomain} return route, nil - default: - return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) } + return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool, skipAutoApply bool) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -104,7 +103,6 @@ func initRoutesTestData() *handler { Groups: groups, KeepRoute: keepRoute, AccessControlGroups: accessControlGroups, - SkipAutoApply: skipAutoApply, }, nil }, SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { @@ -192,20 +190,19 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBuffer( - []byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"],"skip_auto_apply":false}`, existingPeerID, existingGroupID))), + []byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", - Network: util.ToPtr("192.168.0.0/16"), - Peer: &existingPeerID, - NetworkType: route.IPv4NetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, - SkipAutoApply: util.ToPtr(false), + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, }, }, { @@ -213,22 +210,21 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBuffer( - []byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID))), + []byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "domainNet", - Network: util.ToPtr("invalid Prefix"), - KeepRoute: true, - Domains: &[]string{existingDomain}, - Peer: &existingPeerID, - NetworkType: route.DomainNetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, - SkipAutoApply: util.ToPtr(false), + Id: existingRouteID, + Description: "Post", + NetworkId: "domainNet", + Network: util.ToPtr("invalid Prefix"), + KeepRoute: true, + Domains: &[]string{existingDomain}, + Peer: &existingPeerID, + NetworkType: route.DomainNetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, }, }, { @@ -236,7 +232,7 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBuffer( - []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"],\"skip_auto_apply\":false}", existingPeerID, existingGroupID, existingGroupID))), + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ @@ -250,7 +246,6 @@ func TestRoutesHandlers(t *testing.T) { Enabled: false, Groups: []string{existingGroupID}, AccessControlGroups: &[]string{existingGroupID}, - SkipAutoApply: util.ToPtr(false), }, }, { @@ -341,63 +336,60 @@ func TestRoutesHandlers(t *testing.T) { name: "Network PUT OK", requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, - requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"is_selected\":true}", existingPeerID, existingGroupID)), + requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", - Network: util.ToPtr("192.168.0.0/16"), - Peer: &existingPeerID, - NetworkType: route.IPv4NetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, - SkipAutoApply: util.ToPtr(false), + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, }, }, { name: "Domains PUT OK", requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, - requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID)), + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", - Network: util.ToPtr("invalid Prefix"), - Domains: &[]string{existingDomain}, - Peer: &existingPeerID, - NetworkType: route.DomainNetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, - KeepRoute: true, - SkipAutoApply: util.ToPtr(false), + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("invalid Prefix"), + Domains: &[]string{existingDomain}, + Peer: &existingPeerID, + NetworkType: route.DomainNetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + KeepRoute: true, }, }, { name: "PUT OK when peer_groups provided", requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, - requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"],\"skip_auto_apply\":false}", existingGroupID, existingGroupID)), + requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingGroupID, existingGroupID)), expectedStatus: http.StatusOK, expectedBody: true, expectedRoute: &api.Route{ - Id: existingRouteID, - Description: "Post", - NetworkId: "awesomeNet", - Network: util.ToPtr("192.168.0.0/16"), - Peer: &emptyString, - PeerGroups: &[]string{existingGroupID}, - NetworkType: route.IPv4NetworkString, - Masquerade: false, - Enabled: false, - Groups: []string{existingGroupID}, - SkipAutoApply: util.ToPtr(false), + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: util.ToPtr("192.168.0.0/16"), + Peer: &emptyString, + PeerGroups: &[]string{existingGroupID}, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, }, }, { diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index 4e03e5e9b..bcd637db4 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -9,11 +9,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" nbcontext "github.com/netbirdio/netbird/management/server/context" ) @@ -31,8 +31,6 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS") addUsersTokensEndpoint(accountManager, router) } @@ -325,76 +323,17 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { } isCurrent := user.ID == currenUserID - return &api.User{ - Id: user.ID, - Name: user.Name, - Email: user.Email, - Role: user.Role, - AutoGroups: autoGroups, - Status: userStatus, - IsCurrent: &isCurrent, - IsServiceUser: &user.IsServiceUser, - IsBlocked: user.IsBlocked, - LastLogin: &user.LastLogin, - Issued: &user.Issued, - PendingApproval: user.PendingApproval, + Id: user.ID, + Name: user.Name, + Email: user.Email, + Role: user.Role, + AutoGroups: autoGroups, + Status: userStatus, + IsCurrent: &isCurrent, + IsServiceUser: &user.IsServiceUser, + IsBlocked: user.IsBlocked, + LastLogin: &user.LastLogin, + Issued: &user.Issued, } } - -// approveUser is a POST request to approve a user that is pending approval -func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - - vars := mux.Vars(r) - targetUserID := vars["userId"] - if len(targetUserID) == 0 { - util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) - return - } - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - userResponse := toUserResponse(user, userAuth.UserId) - util.WriteJSONObject(r.Context(), w, userResponse) -} - -// rejectUser is a DELETE request to reject a user that is pending approval -func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - - vars := mux.Vars(r) - targetUserID := vars["userId"] - if len(targetUserID) == 0 { - util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) - return - } - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) -} diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index e08004218..f7dc81919 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -16,13 +16,13 @@ import ( "github.com/stretchr/testify/require" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" + "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -725,133 +725,3 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri } return modules } - -func TestApproveUserEndpoint(t *testing.T) { - adminUser := &types.User{ - Id: "admin-user", - Role: types.UserRoleAdmin, - AccountID: existingAccountID, - AutoGroups: []string{}, - } - - pendingUser := &types.User{ - Id: "pending-user", - Role: types.UserRoleUser, - AccountID: existingAccountID, - Blocked: true, - PendingApproval: true, - AutoGroups: []string{}, - } - - tt := []struct { - name string - expectedStatus int - expectedBody bool - requestingUser *types.User - }{ - { - name: "approve user as admin should return 200", - expectedStatus: 200, - expectedBody: true, - requestingUser: adminUser, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - am := &mock_server.MockAccountManager{} - am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { - approvedUserInfo := &types.UserInfo{ - ID: pendingUser.Id, - Email: "pending@example.com", - Name: "Pending User", - Role: string(pendingUser.Role), - AutoGroups: []string{}, - IsServiceUser: false, - IsBlocked: false, - PendingApproval: false, - LastLogin: time.Now(), - Issued: types.UserIssuedAPI, - } - return approvedUserInfo, nil - } - - handler := newHandler(am) - router := mux.NewRouter() - router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST") - - req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) - require.NoError(t, err) - - userAuth := nbcontext.UserAuth{ - AccountId: existingAccountID, - UserId: tc.requestingUser.Id, - } - ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth) - req = req.WithContext(ctx) - - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - - assert.Equal(t, tc.expectedStatus, rr.Code) - - if tc.expectedBody { - var response api.User - err = json.Unmarshal(rr.Body.Bytes(), &response) - require.NoError(t, err) - assert.Equal(t, "pending-user", response.Id) - assert.False(t, response.IsBlocked) - assert.False(t, response.PendingApproval) - } - }) - } -} - -func TestRejectUserEndpoint(t *testing.T) { - adminUser := &types.User{ - Id: "admin-user", - Role: types.UserRoleAdmin, - AccountID: existingAccountID, - AutoGroups: []string{}, - } - - tt := []struct { - name string - expectedStatus int - requestingUser *types.User - }{ - { - name: "reject user as admin should return 200", - expectedStatus: 200, - requestingUser: adminUser, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - am := &mock_server.MockAccountManager{} - am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { - return nil - } - - handler := newHandler(am) - router := mux.NewRouter() - router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE") - - req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) - require.NoError(t, err) - - userAuth := nbcontext.UserAuth{ - AccountId: existingAccountID, - UserId: tc.requestingUser.Id, - } - ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth) - req = req.WithContext(ctx) - - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - - assert.Equal(t, tc.expectedStatus, rr.Code) - }) - } -} diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6091a4c31..f221e64a9 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -13,9 +13,9 @@ import ( "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/management/server/types" ) type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d815f5422..2285ed244 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -8,15 +8,16 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/auth" nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/util" ) const ( diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 3fe3fe809..52737e4eb 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -17,9 +17,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" ) const modulePeers = "peers" @@ -48,7 +47,7 @@ func BenchmarkUpdatePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -66,7 +65,7 @@ func BenchmarkUpdatePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) }) } } @@ -83,7 +82,7 @@ func BenchmarkGetOnePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -93,7 +92,7 @@ func BenchmarkGetOnePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) }) } } @@ -110,7 +109,7 @@ func BenchmarkGetAllPeers(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -120,7 +119,7 @@ func BenchmarkGetAllPeers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) }) } } @@ -137,7 +136,7 @@ func BenchmarkDeletePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -147,7 +146,7 @@ func BenchmarkDeletePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 36b226db0..9404c4ee4 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -17,9 +17,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" ) // Map to store peers, groups, users, and setupKeys by name @@ -48,7 +47,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -70,7 +69,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) }) } } @@ -87,7 +86,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -110,7 +109,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) }) } } @@ -127,7 +126,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -137,7 +136,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) }) } } @@ -154,7 +153,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -164,7 +163,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) }) } } @@ -181,7 +180,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) b.ResetTimer() @@ -191,7 +190,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 2868a20bd..844b3e7a6 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -18,9 +18,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" ) const moduleUsers = "users" @@ -47,7 +46,7 @@ func BenchmarkUpdateUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -72,7 +71,7 @@ func BenchmarkUpdateUser(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) }) } } @@ -85,18 +84,18 @@ func BenchmarkGetOneUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) }) } } @@ -111,18 +110,18 @@ func BenchmarkGetAllUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) }) } } @@ -137,7 +136,7 @@ func BenchmarkDeleteUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -148,7 +147,7 @@ func BenchmarkDeleteUsers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index 1079de4aa..9f04e3c24 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -15,10 +15,9 @@ import ( "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" - "github.com/netbirdio/netbird/shared/management/http/api" ) func Test_SetupKeys_Create(t *testing.T) { @@ -288,7 +287,7 @@ func Test_SetupKeys_Create(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(user.name+" - "+tc.name, func(t *testing.T) { - apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -573,7 +572,7 @@ func Test_SetupKeys_Update(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -752,7 +751,7 @@ func Test_SetupKeys_Get(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) @@ -904,7 +903,7 @@ func Test_SetupKeys_GetAll(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) @@ -1088,7 +1087,7 @@ func Test_SetupKeys_Delete(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go deleted file mode 100644 index 741f03f18..000000000 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ /dev/null @@ -1,137 +0,0 @@ -package channel - -import ( - "context" - "errors" - "net/http" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" - "github.com/stretchr/testify/assert" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/management/server/groups" - http2 "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" - "github.com/netbirdio/netbird/management/server/networks" - "github.com/netbirdio/netbird/management/server/networks/resources" - "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/users" -) - -func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) - if err != nil { - t.Fatalf("Failed to create test store: %v", err) - } - t.Cleanup(cleanup) - - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - if err != nil { - t.Fatalf("Failed to create metrics: %v", err) - } - - peersUpdateManager := server.NewPeersUpdateManager(nil) - updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) - done := make(chan struct{}) - if validateUpdate { - go func() { - if expectedPeerUpdate != nil { - peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) - } else { - peerShouldNotReceiveUpdate(t, updMsg) - } - close(done) - }() - } - - geoMock := &geolocation.Mock{} - validatorMock := server.MockIntegratedValidator{} - proxyController := integrations.NewController(store) - userManager := users.NewManager(store) - permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) - if err != nil { - t.Fatalf("Failed to create manager: %v", err) - } - - // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := auth.NewManager(store, "", "", "", "", []string{}, false) - authManagerMock := &auth.MockManager{ - ValidateAndParseTokenFunc: mockValidateAndParseToken, - EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, - MarkPATUsedFunc: authManager.MarkPATUsed, - GetPATInfoFunc: authManager.GetPATInfo, - } - - networksManagerMock := networks.NewManagerMock() - resourcesManagerMock := resources.NewManagerMock() - routersManagerMock := routers.NewManagerMock() - groupsManagerMock := groups.NewManagerMock() - peersManager := peers.NewManager(store, permissionsManager) - - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) - if err != nil { - t.Fatalf("Failed to create API handler: %v", err) - } - - return apiHandler, am, done -} - -func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { - t.Helper() - select { - case msg := <-updateMessage: - t.Errorf("Unexpected message received: %+v", msg) - case <-time.After(500 * time.Millisecond): - return - } -} - -func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { - t.Helper() - - select { - case msg := <-updateMessage: - if msg == nil { - t.Errorf("Received nil update message, expected valid message") - } - assert.Equal(t, expected, msg) - case <-time.After(500 * time.Millisecond): - t.Errorf("Timed out waiting for update message") - } -} - -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} - - switch token { - case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": - userAuth.UserId = token - userAuth.AccountId = "testAccountId" - userAuth.Domain = "test.com" - userAuth.DomainCategory = "private" - case "otherUserId": - userAuth.UserId = "otherUserId" - userAuth.AccountId = "otherAccountId" - userAuth.Domain = "other.com" - userAuth.DomainCategory = "private" - case "invalidToken": - return userAuth, nil, errors.New("invalid token") - } - - jwtToken := jwt.New(jwt.SigningMethodHS256) - return userAuth, jwtToken, nil -} diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index 2186ecd46..cc3a2a8f6 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -3,6 +3,7 @@ package testing_tools import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -13,12 +14,32 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt" "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/users" + + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" ) @@ -202,11 +223,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta return content, expectedStatus == http.StatusOK } -func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) { +func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { b.Helper() ctx := context.Background() - acc, err := am.GetAccount(ctx, TestAccountId) + account, err := am.GetAccount(ctx, TestAccountId) if err != nil { b.Fatalf("Failed to get account: %v", err) } @@ -222,23 +243,23 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: TestUserId, } - acc.Peers[peer.ID] = peer + account.Peers[peer.ID] = peer } // Create users for i := 0; i < users; i++ { user := &types.User{ Id: fmt.Sprintf("olduser-%d", i), - AccountID: acc.Id, + AccountID: account.Id, Role: types.UserRoleUser, } - acc.Users[user.Id] = user + account.Users[user.Id] = user } for i := 0; i < setupKeys; i++ { key := &types.SetupKey{ Id: fmt.Sprintf("oldkey-%d", i), - AccountID: acc.Id, + AccountID: account.Id, AutoGroups: []string{"someGroupID"}, UpdatedAt: time.Now().UTC(), ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)), @@ -246,11 +267,11 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se Type: "reusable", UsageLimit: 0, } - acc.SetupKeys[key.Id] = key + account.SetupKeys[key.Id] = key } // Create groups and policies - acc.Policies = make([]*types.Policy, 0, groups) + account.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) group := &types.Group{ @@ -261,7 +282,7 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } - acc.Groups[groupID] = group + account.Groups[groupID] = group // Create a policy for this group policy := &types.Policy{ @@ -281,10 +302,10 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se }, }, } - acc.Policies = append(acc.Policies, policy) + account.Policies = append(account.Policies, policy) } - acc.PostureChecks = []*posture.Checks{ + account.PostureChecks = []*posture.Checks{ { ID: "PostureChecksAll", Name: "All", @@ -296,38 +317,52 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se }, } - store := am.GetStore() - - err = store.SaveAccount(context.Background(), acc) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { b.Fatalf("Failed to save account: %v", err) } } -func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { - b.Helper() - - if recorder.Code != http.StatusOK { - b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code) - } - - EvaluateBenchmarkResults(b, testCase, duration, module, operation) - -} - -func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) { +func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { b.Helper() branch := os.Getenv("GIT_BRANCH") - if branch == "" && os.Getenv("CI") == "true" { + if branch == "" { b.Fatalf("environment variable GIT_BRANCH is not set") } + if recorder.Code != http.StatusOK { + b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code) + } + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch) gauge.Set(msPerOp) b.ReportMetric(msPerOp, "ms/op") + +} + +func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { + userAuth := nbcontext.UserAuth{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + userAuth.UserId = token + userAuth.AccountId = "testAccountId" + userAuth.Domain = "test.com" + userAuth.DomainCategory = "private" + case "otherUserId": + userAuth.UserId = "otherUserId" + userAuth.AccountId = "otherAccountId" + userAuth.Domain = "other.com" + userAuth.DomainCategory = "private" + case "invalidToken": + return userAuth, nil, errors.New("invalid token") + } + + jwtToken := jwt.New(jwt.SigningMethodHS256) + return userAuth, jwtToken, nil } diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 1eb8434d3..497f1944f 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -4,7 +4,6 @@ import ( "bytes" "compress/gzip" "context" - "encoding/base64" "encoding/json" "fmt" "io" @@ -17,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" ) @@ -231,7 +231,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" { return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 66c16870b..f8a0e1210 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -11,11 +11,12 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/telemetry" + + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" ) type mockHTTPClient struct { diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 2f87a9bba..00d30d645 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -2,7 +2,6 @@ package idp import ( "context" - "encoding/base64" "fmt" "io" "net/http" @@ -12,6 +11,7 @@ import ( "sync" "time" + "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" @@ -166,7 +166,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) ( return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 393a39e3e..35b86764d 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -2,7 +2,6 @@ package idp import ( "context" - "encoding/base64" "fmt" "io" "net/http" @@ -11,6 +10,7 @@ import ( "sync" "time" + "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/telemetry" @@ -168,7 +168,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index c611317ab..07d84058c 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -2,7 +2,6 @@ package idp import ( "context" - "encoding/base64" "fmt" "io" "net/http" @@ -12,6 +11,7 @@ import ( "sync" "time" + "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/telemetry" @@ -158,7 +158,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 24228346a..343357927 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -2,7 +2,6 @@ package idp import ( "context" - "encoding/base64" "errors" "fmt" "io" @@ -13,6 +12,7 @@ import ( "sync" "time" + "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/telemetry" @@ -253,7 +253,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) } - data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1]) + data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) if err != nil { return jwtToken, err } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 21f11bfce..509022015 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -46,6 +46,9 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context, groups = []string{} } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { diff --git a/management/server/integrations/port_forwarding/controller.go b/management/server/integrations/port_forwarding/controller.go index f2ce81839..6f062bb12 100644 --- a/management/server/integrations/port_forwarding/controller.go +++ b/management/server/integrations/port_forwarding/controller.go @@ -3,14 +3,12 @@ package port_forwarding import ( "context" - "github.com/netbirdio/netbird/management/server/peer" nbtypes "github.com/netbirdio/netbird/management/server/types" ) type Controller interface { - SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer) - GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) - GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) + SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string) + GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error) IsPeerInIngressPorts(ctx context.Context, accountID, peerID string) (bool, error) } @@ -21,15 +19,11 @@ func NewControllerMock() *ControllerMock { return &ControllerMock{} } -func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer) { +func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string) { // noop } -func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) { - return make(map[string]*nbtypes.NetworkMap), nil -} - -func (c *ControllerMock) GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) { +func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error) { return make(map[string]*nbtypes.NetworkMap), nil } diff --git a/management/server/loginfilter.go b/management/server/loginfilter.go deleted file mode 100644 index 8604af6e2..000000000 --- a/management/server/loginfilter.go +++ /dev/null @@ -1,160 +0,0 @@ -package server - -import ( - "hash/fnv" - "math" - "sync" - "time" - - nbpeer "github.com/netbirdio/netbird/management/server/peer" -) - -const ( - reconnThreshold = 5 * time.Minute - baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit - reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban - metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer -) - -type lfConfig struct { - reconnThreshold time.Duration - baseBlockDuration time.Duration - reconnLimitForBan int - metaChangeLimit int -} - -func initCfg() *lfConfig { - return &lfConfig{ - reconnThreshold: reconnThreshold, - baseBlockDuration: baseBlockDuration, - reconnLimitForBan: reconnLimitForBan, - metaChangeLimit: metaChangeLimit, - } -} - -type loginFilter struct { - mu sync.RWMutex - cfg *lfConfig - logged map[string]*peerState -} - -type peerState struct { - currentHash uint64 - sessionCounter int - sessionStart time.Time - lastSeen time.Time - isBanned bool - banLevel int - banExpiresAt time.Time - metaChangeCounter int - metaChangeWindowStart time.Time -} - -func newLoginFilter() *loginFilter { - return newLoginFilterWithCfg(initCfg()) -} - -func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter { - return &loginFilter{ - logged: make(map[string]*peerState), - cfg: cfg, - } -} - -func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool { - l.mu.RLock() - defer func() { - l.mu.RUnlock() - }() - state, ok := l.logged[wgPubKey] - if !ok { - return true - } - if state.isBanned && time.Now().Before(state.banExpiresAt) { - return false - } - if metaHash != state.currentHash { - if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit { - return false - } - } - return true -} - -func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) { - now := time.Now() - l.mu.Lock() - defer func() { - l.mu.Unlock() - }() - - state, ok := l.logged[wgPubKey] - - if !ok { - l.logged[wgPubKey] = &peerState{ - currentHash: metaHash, - sessionCounter: 1, - sessionStart: now, - lastSeen: now, - metaChangeWindowStart: now, - metaChangeCounter: 1, - } - return - } - - if state.isBanned && now.After(state.banExpiresAt) { - state.isBanned = false - } - - if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) { - state.banLevel = 0 - } - - if metaHash != state.currentHash { - if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) { - state.metaChangeWindowStart = now - state.metaChangeCounter = 1 - } else { - state.metaChangeCounter++ - } - state.currentHash = metaHash - state.sessionCounter = 1 - state.sessionStart = now - state.lastSeen = now - return - } - - state.sessionCounter++ - if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold { - state.isBanned = true - state.banLevel++ - - backoffFactor := math.Pow(2, float64(state.banLevel-1)) - duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor) - state.banExpiresAt = now.Add(duration) - - state.sessionCounter = 0 - state.sessionStart = now - } - state.lastSeen = now -} - -func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 { - h := fnv.New64a() - - h.Write([]byte(meta.WtVersion)) - h.Write([]byte(meta.OSVersion)) - h.Write([]byte(meta.KernelVersion)) - h.Write([]byte(meta.Hostname)) - h.Write([]byte(meta.SystemSerialNumber)) - h.Write([]byte(pubip)) - - macs := uint64(0) - for _, na := range meta.NetworkAddresses { - for _, r := range na.Mac { - macs += uint64(r) - } - } - - return h.Sum64() + macs -} diff --git a/management/server/loginfilter_test.go b/management/server/loginfilter_test.go deleted file mode 100644 index 65782dd9d..000000000 --- a/management/server/loginfilter_test.go +++ /dev/null @@ -1,275 +0,0 @@ -package server - -import ( - "hash/fnv" - "math" - "math/rand" - "strconv" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/suite" - - nbpeer "github.com/netbirdio/netbird/management/server/peer" -) - -func testAdvancedCfg() *lfConfig { - return &lfConfig{ - reconnThreshold: 50 * time.Millisecond, - baseBlockDuration: 100 * time.Millisecond, - reconnLimitForBan: 3, - metaChangeLimit: 2, - } -} - -type LoginFilterTestSuite struct { - suite.Suite - filter *loginFilter -} - -func (s *LoginFilterTestSuite) SetupTest() { - s.filter = newLoginFilterWithCfg(testAdvancedCfg()) -} - -func TestLoginFilterTestSuite(t *testing.T) { - suite.Run(t, new(LoginFilterTestSuite)) -} - -func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() { - pubKey := "PUB_KEY_A" - meta := uint64(1) - - s.True(s.filter.allowLogin(pubKey, meta)) - - s.filter.addLogin(pubKey, meta) - s.Require().Contains(s.filter.logged, pubKey) - s.Equal(1, s.filter.logged[pubKey].sessionCounter) -} - -func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() { - pubKey := "PUB_KEY_A" - meta := uint64(1) - limit := s.filter.cfg.reconnLimitForBan - - for i := 0; i <= limit; i++ { - s.filter.addLogin(pubKey, meta) - } - - s.False(s.filter.allowLogin(pubKey, meta)) - s.Require().Contains(s.filter.logged, pubKey) - s.True(s.filter.logged[pubKey].isBanned) -} - -func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() { - pubKey := "PUB_KEY_A" - meta := uint64(1) - limit := s.filter.cfg.reconnLimitForBan - baseBan := s.filter.cfg.baseBlockDuration - - for i := 0; i <= limit; i++ { - s.filter.addLogin(pubKey, meta) - } - s.Require().Contains(s.filter.logged, pubKey) - s.True(s.filter.logged[pubKey].isBanned) - s.Equal(1, s.filter.logged[pubKey].banLevel) - firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) - s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond)) - - s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second) - s.filter.logged[pubKey].isBanned = false - - for i := 0; i <= limit; i++ { - s.filter.addLogin(pubKey, meta) - } - s.True(s.filter.logged[pubKey].isBanned) - s.Equal(2, s.filter.logged[pubKey].banLevel) - secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) - expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1)) - s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond)) -} - -func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() { - pubKey := "PUB_KEY_A" - meta := uint64(1) - - s.filter.logged[pubKey] = &peerState{ - isBanned: true, - banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)), - } - - s.True(s.filter.allowLogin(pubKey, meta)) - - s.filter.addLogin(pubKey, meta) - s.Require().Contains(s.filter.logged, pubKey) - s.False(s.filter.logged[pubKey].isBanned) -} - -func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() { - pubKey := "PUB_KEY_A" - meta := uint64(1) - - s.filter.logged[pubKey] = &peerState{ - currentHash: meta, - banLevel: 3, - lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration), - } - - s.filter.addLogin(pubKey, meta) - s.Require().Contains(s.filter.logged, pubKey) - s.Equal(0, s.filter.logged[pubKey].banLevel) -} - -func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() { - pubKey := "PUB_KEY_A" - limit := s.filter.cfg.metaChangeLimit - - for i := range limit { - s.filter.addLogin(pubKey, uint64(i+1)) - } - - s.Require().Contains(s.filter.logged, pubKey) - s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter) - - isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1)) - - s.False(isAllowed, "should block new meta hash after limit is reached") -} - -func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() { - pubKey := "PUB_KEY_A" - meta1 := uint64(1) - meta2 := uint64(2) - meta3 := uint64(3) - - s.filter.addLogin(pubKey, meta1) - s.filter.addLogin(pubKey, meta2) - s.Require().Contains(s.filter.logged, pubKey) - s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter) - s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window") - - s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second)) - - s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires") - - s.filter.addLogin(pubKey, meta3) - s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset") -} - -func BenchmarkHashingMethods(b *testing.B) { - meta := nbpeer.PeerSystemMeta{ - WtVersion: "1.25.1", - OSVersion: "Ubuntu 22.04.3 LTS", - KernelVersion: "5.15.0-76-generic", - Hostname: "prod-server-database-01", - SystemSerialNumber: "PC-1234567890", - NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}}, - } - pubip := "8.8.8.8" - - var resultString string - var resultUint uint64 - - b.Run("BuilderString", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - resultString = builderString(meta, pubip) - } - }) - - b.Run("FnvHashToString", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - resultString = fnvHashToString(meta, pubip) - } - }) - - b.Run("FnvHashToUint64 - used", func(b *testing.B) { - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - resultUint = metaHash(meta, pubip) - } - }) - - _ = resultString - _ = resultUint -} - -func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string { - h := fnv.New64a() - - if len(meta.NetworkAddresses) != 0 { - for _, na := range meta.NetworkAddresses { - h.Write([]byte(na.Mac)) - } - } - - h.Write([]byte(meta.WtVersion)) - h.Write([]byte(meta.OSVersion)) - h.Write([]byte(meta.KernelVersion)) - h.Write([]byte(meta.Hostname)) - h.Write([]byte(meta.SystemSerialNumber)) - h.Write([]byte(pubip)) - - return strconv.FormatUint(h.Sum64(), 16) -} - -func builderString(meta nbpeer.PeerSystemMeta, pubip string) string { - mac := getMacAddress(meta.NetworkAddresses) - estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + - len(pubip) + len(mac) + 6 - - var b strings.Builder - b.Grow(estimatedSize) - - b.WriteString(meta.WtVersion) - b.WriteByte('|') - b.WriteString(meta.OSVersion) - b.WriteByte('|') - b.WriteString(meta.KernelVersion) - b.WriteByte('|') - b.WriteString(meta.Hostname) - b.WriteByte('|') - b.WriteString(meta.SystemSerialNumber) - b.WriteByte('|') - b.WriteString(pubip) - - return b.String() -} - -func getMacAddress(nas []nbpeer.NetworkAddress) string { - if len(nas) == 0 { - return "" - } - macs := make([]string, 0, len(nas)) - for _, na := range nas { - macs = append(macs, na.Mac) - } - return strings.Join(macs, "/") -} - -func BenchmarkLoginFilter_ParallelLoad(b *testing.B) { - filter := newLoginFilterWithCfg(testAdvancedCfg()) - numKeys := 100000 - pubKeys := make([]string, numKeys) - for i := range numKeys { - pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i) - } - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - - for pb.Next() { - key := pubKeys[r.Intn(numKeys)] - meta := r.Uint64() - - if filter.allowLogin(key, meta) { - filter.addLogin(key, meta) - } - } - }) -} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index cc79082a6..2a27bf6a7 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -61,7 +61,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -95,8 +95,6 @@ type MockAccountManager struct { LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error - ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) - RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() account.ExternalCacheManager @@ -518,9 +516,9 @@ func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userI } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute, isSelected) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } @@ -631,20 +629,6 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented") } -func (am *MockAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { - if am.ApproveUserFunc != nil { - return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID) - } - return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented") -} - -func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { - if am.RejectUserFunc != nil { - return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID) - } - return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented") -} - // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if am.GetNameServerGroupFunc != nil { @@ -993,10 +977,3 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n } return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } - -func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { - if am.AllowSyncFunc != nil { - return am.AllowSyncFunc(key, hash) - } - return true -} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index f278e1761..1ee8805fc 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -37,6 +37,9 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -70,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return err } - if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.SaveNameServerGroup(ctx, newNSGroup) }) if err != nil { return nil, err @@ -91,6 +94,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if nsGroupToSave == nil { return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } @@ -121,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.SaveNameServerGroup(ctx, nsGroupToSave) }) if err != nil { return err @@ -142,6 +148,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -164,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID) }) if err != nil { return err diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..2bab0e289 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -70,6 +70,9 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network network.ID = xid.New().String() + unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) + defer unlock() + err = m.store.SaveNetwork(ctx, network) if err != nil { return nil, fmt.Errorf("failed to save network: %w", err) @@ -101,6 +104,9 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network return nil, status.NewPermissionDeniedError() } + unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) + defer unlock() + _, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID) if err != nil { return nil, fmt.Errorf("failed to get network: %w", err) @@ -125,6 +131,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return fmt.Errorf("failed to get network: %w", err) } + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var eventsToStore []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) @@ -158,15 +167,15 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return fmt.Errorf("failed to delete network: %w", err) } - eventsToStore = append(eventsToStore, func() { - m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) - }) - err = transaction.IncrementNetworkSerial(ctx, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) + }) + return nil }) if err != nil { diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 294f51676..d0b29075b 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -108,6 +108,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return nil, fmt.Errorf("failed to create new network resource: %w", err) } + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + var eventsToStore []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name) @@ -201,6 +204,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc resource.Domain = domain resource.Prefix = prefix + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + var eventsToStore []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) @@ -309,6 +315,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net return status.NewPermissionDeniedError() } + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var events []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID) diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..ca99e4fd1 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -88,6 +88,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t return nil, status.NewPermissionDeniedError() } + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + var network *networkTypes.Network err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) @@ -154,6 +157,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return nil, status.NewPermissionDeniedError() } + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + var network *networkTypes.Network err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) @@ -197,6 +203,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo return status.NewPermissionDeniedError() } + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var event func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID) diff --git a/management/server/peer.go b/management/server/peer.go index 3b9622d09..16abf2b40 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -192,6 +192,9 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -458,6 +461,9 @@ func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID, // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -480,7 +486,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) if err != nil { return err } @@ -494,6 +500,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } + if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { + return fmt.Errorf("failed to remove peer from groups: %w", err) + } + eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -543,7 +553,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin } customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return nil, err @@ -615,9 +625,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found") } - if user.PendingApproval { - return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") - } groupsToAdd = user.AutoGroups opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser @@ -728,6 +735,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s newPeer.DNSLabel = freeLabel newPeer.IP = freeIP + unlock := am.Store.AcquireReadLockByUID(ctx, accountID) + defer func() { + if unlock != nil { + unlock() + } + }() + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = transaction.AddPeerToAccount(ctx, newPeer) if err != nil { @@ -779,10 +793,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil }) if err == nil { + unlock() + unlock = nil break } if isUniqueConstraintError(err) { + unlock() + unlock = nil log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err) continue } @@ -941,6 +959,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer } } + unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID) + defer unlockAccount() + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey) + defer func() { + if unlockPeer != nil { + unlockPeer() + } + }() + var peer *nbpeer.Peer var updateRemotePeers bool var isRequiresApproval bool @@ -1021,6 +1048,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + unlockPeer() + unlockPeer = nil + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -1152,7 +1182,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return nil, nil, nil, err @@ -1325,7 +1355,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return @@ -1464,7 +1494,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers) + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) return @@ -1652,7 +1682,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto } dnsDomain := am.GetDNSDomain(settings) - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 6a6d1c91d..f7140e254 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -24,7 +24,7 @@ type Peer struct { // Meta is a Peer system meta data Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) - Name string `gorm:"index"` + Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // uniqueness index per accountID (check migrations) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b2d5b7e39..92b240285 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -26,7 +26,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" @@ -990,14 +989,19 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") + minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD - testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp > maxExpected { - b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > (maxExpected * 1.1) { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -1605,6 +1609,7 @@ func Test_LoginPeer(t *testing.T) { testCases := []struct { name string setupKey string + wireGuardPubKey string expectExtraDNSLabelsMismatch bool extraDNSLabels []string expectLoginError bool @@ -1968,7 +1973,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, - route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply, + route.Groups, []string{}, true, userID, route.KeepRoute, ) require.NoError(t, err) @@ -2383,186 +2388,3 @@ func TestBufferUpdateAccountPeers(t *testing.T) { assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) } - -func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") - pendingUser.AccountID = account.Id - pendingUser.Blocked = true - pendingUser.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser) - require.NoError(t, err) - - // Try to add peer with pending approval user - key, err := wgtypes.GenerateKey() - require.NoError(t, err) - - peer := &nbpeer.Peer{ - Key: key.PublicKey().String(), - Name: "test-peer", - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-peer", - OS: "linux", - }, - } - - _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer) - require.Error(t, err) - assert.Contains(t, err.Error(), "user pending approval cannot add peers") -} - -func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Create regular user (not pending approval) - regularUser := types.NewRegularUser("regular-user") - regularUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), regularUser) - require.NoError(t, err) - - // Try to add peer with regular user - key, err := wgtypes.GenerateKey() - require.NoError(t, err) - - peer := &nbpeer.Peer{ - Key: key.PublicKey().String(), - Name: "test-peer", - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-peer", - OS: "linux", - }, - } - - _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) - require.NoError(t, err, "Regular user should be able to add peers") -} - -func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") - pendingUser.AccountID = account.Id - pendingUser.Blocked = true - pendingUser.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser) - require.NoError(t, err) - - // Create a peer using AddPeer method for the pending user (simulate existing peer) - key, err := wgtypes.GenerateKey() - require.NoError(t, err) - - // Set the user to not be pending initially so peer can be added - pendingUser.Blocked = false - pendingUser.PendingApproval = false - err = manager.Store.SaveUser(context.Background(), pendingUser) - require.NoError(t, err) - - // Add peer using regular flow - newPeer := &nbpeer.Peer{ - Key: key.PublicKey().String(), - Name: "test-peer", - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-peer", - OS: "linux", - WtVersion: "0.28.0", - }, - } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) - require.NoError(t, err) - - // Now set the user back to pending approval after peer was created - pendingUser.Blocked = true - pendingUser.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser) - require.NoError(t, err) - - // Try to login with pending approval user - login := types.PeerLogin{ - WireGuardPubKey: existingPeer.Key, - UserID: pendingUser.Id, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-peer", - OS: "linux", - }, - } - - _, _, _, err = manager.LoginPeer(context.Background(), login) - require.Error(t, err) - e, ok := status.FromError(err) - require.True(t, ok, "error is not a gRPC status error") - assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code") -} - -func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create account - account := newAccountWithId(context.Background(), "test-account", "owner", "", false) - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Create regular user (not pending approval) - regularUser := types.NewRegularUser("regular-user") - regularUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), regularUser) - require.NoError(t, err) - - // Add peer using regular flow for the regular user - key, err := wgtypes.GenerateKey() - require.NoError(t, err) - - newPeer := &nbpeer.Peer{ - Key: key.PublicKey().String(), - Name: "test-peer", - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-peer", - OS: "linux", - WtVersion: "0.28.0", - }, - } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) - require.NoError(t, err) - - // Try to login with regular user - login := types.PeerLogin{ - WireGuardPubKey: existingPeer.Key, - UserID: regularUser.Id, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-peer", - OS: "linux", - }, - } - - _, _, _, err = manager.LoginPeer(context.Background(), login) - require.NoError(t, err, "Regular user should be able to login peers") -} diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go index cb135f4ac..50e36a880 100644 --- a/management/server/peers/manager.go +++ b/management/server/peers/manager.go @@ -18,7 +18,6 @@ type Manager interface { GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) GetPeerAccountID(ctx context.Context, peerID string) (string, error) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) - GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) } type managerImpl struct { @@ -62,7 +61,3 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) } - -func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { - return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) -} diff --git a/management/server/peers/manager_mock.go b/management/server/peers/manager_mock.go index 994f8346b..b247a1752 100644 --- a/management/server/peers/manager_mock.go +++ b/management/server/peers/manager_mock.go @@ -79,18 +79,3 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID) } - -// GetPeersByGroupIDs mocks base method. -func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs) - ret0, _ := ret[0].([]*peer.Peer) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs. -func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) -} diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..0ab244243 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -54,14 +54,10 @@ func (m *managerImpl) ValidateUserPermissions( return false, status.NewUserNotFoundError(userID) } - if user.IsBlocked() && !user.PendingApproval { + if user.IsBlocked() { return false, status.NewUserBlockedError() } - if user.IsBlocked() && user.PendingApproval { - return false, status.NewUserPendingApprovalError() - } - if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil { return false, err } diff --git a/management/server/policy.go b/management/server/policy.go index 3adee6397..d5c66e9f8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -32,6 +32,9 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + operation := operations.Create if !create { operation = operations.Update @@ -58,17 +61,17 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + saveFunc := transaction.CreatePolicy if isUpdate { action = activity.PolicyUpdated saveFunc = transaction.SavePolicy } - if err = saveFunc(ctx, policy); err != nil { - return err - } - - return transaction.IncrementNetworkSerial(ctx, accountID) + return saveFunc(ctx, policy) }) if err != nil { return nil, err @@ -85,6 +88,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -107,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.DeletePolicy(ctx, accountID, policyID) }) if err != nil { return err @@ -167,22 +173,10 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a // validatePolicy validates the policy and its rules. func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return err } - - // TODO: Refactor to support multiple rules per policy - existingRuleIDs := make(map[string]bool) - for _, rule := range existingPolicy.Rules { - existingRuleIDs[rule.ID] = true - } - - for _, rule := range policy.Rules { - if rule.ID != "" && !existingRuleIDs[rule.ID] { - return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) - } - } } else { policy.ID = xid.New().String() policy.AccountID = accountID diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 943f2a970..9414b8065 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -32,6 +32,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID // SavePostureChecks saves a posture check. func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + operation := operations.Create if !create { operation = operations.Update @@ -59,19 +62,15 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return err + } + action = activity.PostureCheckUpdated } postureChecks.AccountID = accountID - if err = transaction.SavePostureChecks(ctx, postureChecks); err != nil { - return err - } - - if isUpdate { - return transaction.IncrementNetworkSerial(ctx, accountID) - } - - return nil + return transaction.SavePostureChecks(ctx, postureChecks) }) if err != nil { return nil, err @@ -88,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) if err != nil { return status.NewPermissionValidationError(err) @@ -108,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if err = transaction.DeletePostureChecks(ctx, accountID, postureChecksID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.DeletePostureChecks(ctx, accountID, postureChecksID) }) if err != nil { return err diff --git a/management/server/route.go b/management/server/route.go index 4510426bb..b853d9cd6 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -134,7 +134,10 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) { +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -167,7 +170,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri Enabled: enabled, Groups: groups, AccessControlGroups: accessControlGroupIDs, - SkipAutoApply: skipAutoApply, } if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil { @@ -179,11 +181,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return err } - if err = transaction.SaveRoute(ctx, newRoute); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.SaveRoute(ctx, newRoute) }) if err != nil { return nil, err @@ -200,6 +202,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri // SaveRoute saves route func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) if err != nil { return status.NewPermissionValidationError(err) @@ -233,11 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } routeToSave.AccountID = accountID - if err = transaction.SaveRoute(ctx, routeToSave); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.SaveRoute(ctx, routeToSave) }) if err != nil { return err @@ -254,6 +259,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -276,11 +284,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return err } - if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.IncrementNetworkSerial(ctx, accountID) + return transaction.DeleteRoute(ctx, accountID, string(routeID)) }) if err != nil { return fmt.Errorf("failed to delete route %s: %w", routeID, err) @@ -374,16 +382,15 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID func toProtocolRoute(route *route.Route) *proto.Route { return &proto.Route{ - ID: string(route.ID), - NetID: string(route.NetID), - Network: route.Network.String(), - Domains: route.Domains.ToPunycodeList(), - NetworkType: int64(route.NetworkType), - Peer: route.Peer, - Metric: int64(route.Metric), - Masquerade: route.Masquerade, - KeepRoute: route.KeepRoute, - SkipAutoApply: route.SkipAutoApply, + ID: string(route.ID), + NetID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), + NetworkType: int64(route.NetworkType), + Peer: route.Peer, + Metric: int64(route.Metric), + Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, } } diff --git a/management/server/route_test.go b/management/server/route_test.go index aeeeb736b..6c61fdf9c 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -69,7 +69,6 @@ func TestCreateRoute(t *testing.T) { enabled bool groups []string accessControlGroups []string - skipAutoApply bool } testCases := []struct { @@ -445,13 +444,13 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false, true) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false, true) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute, testCase.inputArgs.skipAutoApply) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -1085,7 +1084,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1177,7 +1176,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -2005,7 +2004,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, - route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply, + route.Groups, []string{}, true, userID, route.KeepRoute, ) require.NoError(t, err) @@ -2041,7 +2040,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, - route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply, + route.Groups, []string{}, true, userID, route.KeepRoute, ) require.NoError(t, err) @@ -2077,7 +2076,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { newRoute, err := manager.CreateRoute( context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, - baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, !baseRoute.SkipAutoApply, + baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, ) require.NoError(t, err) baseRoute = *newRoute @@ -2143,7 +2142,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, - newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) require.NoError(t, err) @@ -2183,7 +2182,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { _, err := manager.CreateRoute( context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, - newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) require.NoError(t, err) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 8d0509871..71915b4a2 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -55,6 +55,8 @@ type SetupKeyUpdateOperation struct { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create) if err != nil { @@ -105,6 +107,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index a820d99a9..f27eddb2f 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -52,6 +52,7 @@ const ( // SqlStore represents an account storage backed by a Sql DB persisted to disk type SqlStore struct { db *gorm.DB + resourceLocks sync.Map globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int @@ -218,6 +219,44 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { return unlock } +// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock +func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID) + + startWait := time.Now() + value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) + mtx.Lock() + log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait)) + startHold := time.Now() + + unlock = func() { + mtx.Unlock() + log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold)) + } + + return unlock +} + +// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock +func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID) + + startWait := time.Now() + value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) + mtx.RLock() + log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait)) + startHold := time.Now() + + unlock = func() { + mtx.RUnlock() + log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold)) + } + + return unlock +} + // Deprecated: Full account operations are no longer supported func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error { start := time.Now() @@ -989,7 +1028,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { var account types.Account - result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account) + result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account) if result.Error != nil { return "", status.NewGetAccountFromStoreError(result.Error) } @@ -1474,7 +1513,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI PeerID: peerID, } - err := s.db.Clauses(clause.OnConflict{ + err := s.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, DoNothing: true, }).Create(peer).Error @@ -1489,7 +1528,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI // RemovePeerFromGroup removes a peer from a group func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error { - err := s.db. + err := s.db.WithContext(ctx). Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error if err != nil { @@ -1502,7 +1541,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group // RemovePeerFromAllGroups removes a peer from all groups func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { - err := s.db. + err := s.db.WithContext(ctx). Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error if err != nil { @@ -2090,7 +2129,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } @@ -2781,7 +2820,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) { - tx := s.db + tx := s.db.WithContext(ctx) if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -2922,22 +2961,3 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i } return nil } - -func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) { - if len(groupIDs) == 0 { - return []*nbpeer.Peer{}, nil - } - - var peers []*nbpeer.Peer - peerIDsSubquery := s.db.Model(&types.GroupPeer{}). - Select("DISTINCT peer_id"). - Where("account_id = ? AND group_id IN ?", accountID, groupIDs) - - result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers) - if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get peers by group IDs") - } - - return peers, nil -} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index d40c4664c..935b0a595 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3607,113 +3607,3 @@ func intToIPv4(n uint32) net.IP { binary.BigEndian.PutUint32(ip, n) return ip } - -func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { - accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - - group1ID := "test-group-1" - group2ID := "test-group-2" - emptyGroupID := "empty-group" - - peer1 := "cfefqs706sqkneg59g4g" - peer2 := "cfeg6sf06sqkneg59g50" - - tests := []struct { - name string - groupIDs []string - expectedPeers []string - expectedCount int - }{ - { - name: "retrieve peers from single group with multiple peers", - groupIDs: []string{group1ID}, - expectedPeers: []string{peer1, peer2}, - expectedCount: 2, - }, - { - name: "retrieve peers from single group with one peer", - groupIDs: []string{group2ID}, - expectedPeers: []string{peer1}, - expectedCount: 1, - }, - { - name: "retrieve peers from multiple groups (with overlap)", - groupIDs: []string{group1ID, group2ID}, - expectedPeers: []string{peer1, peer2}, // should deduplicate - expectedCount: 2, - }, - { - name: "retrieve peers from existing 'All' group", - groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data - expectedPeers: []string{peer1, peer2}, - expectedCount: 2, - }, - { - name: "retrieve peers from empty group", - groupIDs: []string{emptyGroupID}, - expectedPeers: []string{}, - expectedCount: 0, - }, - { - name: "retrieve peers from non-existing group", - groupIDs: []string{"non-existing-group"}, - expectedPeers: []string{}, - expectedCount: 0, - }, - { - name: "empty group IDs list", - groupIDs: []string{}, - expectedPeers: []string{}, - expectedCount: 0, - }, - { - name: "mix of existing and non-existing groups", - groupIDs: []string{group1ID, "non-existing-group"}, - expectedPeers: []string{peer1, peer2}, - expectedCount: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) - t.Cleanup(cleanup) - require.NoError(t, err) - - ctx := context.Background() - - groups := []*types.Group{ - { - ID: group1ID, - AccountID: accountID, - }, - { - ID: group2ID, - AccountID: accountID, - }, - } - require.NoError(t, store.CreateGroups(ctx, accountID, groups)) - - require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID)) - require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID)) - require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID)) - - peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs) - require.NoError(t, err) - require.Len(t, peers, tt.expectedCount) - - if tt.expectedCount > 0 { - actualPeerIDs := make([]string, len(peers)) - for i, peer := range peers { - actualPeerIDs[i] = peer.ID - } - assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs) - - // Verify all returned peers belong to the correct account - for _, peer := range peers { - assert.Equal(t, accountID, peer.AccountID) - } - } - }) - } -} diff --git a/management/server/store/store.go b/management/server/store/store.go index 31d027c36..d8566e086 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -136,7 +136,6 @@ type Store interface { GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) - GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) @@ -169,6 +168,10 @@ type Store interface { GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error + // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock + AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() + // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock + AcquireReadLockByUID(ctx context.Context, uniqueID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index d4301802f..ac6ff2ea8 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -4,28 +4,20 @@ import ( "context" "time" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) -const AccountIDLabel = "account_id" -const HighLatencyThreshold = time.Second * 7 - // GRPCMetrics are gRPC server metrics type GRPCMetrics struct { - meter metric.Meter - syncRequestsCounter metric.Int64Counter - syncRequestsBlockedCounter metric.Int64Counter - syncRequestHighLatencyCounter metric.Int64Counter - loginRequestsCounter metric.Int64Counter - loginRequestsBlockedCounter metric.Int64Counter - loginRequestHighLatencyCounter metric.Int64Counter - getKeyRequestsCounter metric.Int64Counter - activeStreamsGauge metric.Int64ObservableGauge - syncRequestDuration metric.Int64Histogram - loginRequestDuration metric.Int64Histogram - channelQueueLength metric.Int64Histogram - ctx context.Context + meter metric.Meter + syncRequestsCounter metric.Int64Counter + loginRequestsCounter metric.Int64Counter + getKeyRequestsCounter metric.Int64Counter + activeStreamsGauge metric.Int64ObservableGauge + syncRequestDuration metric.Int64Histogram + loginRequestDuration metric.Int64Histogram + channelQueueLength metric.Int64Histogram + ctx context.Context } // NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server @@ -38,22 +30,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } - syncRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.sync.request.blocked.counter", - metric.WithUnit("1"), - metric.WithDescription("Number of sync gRPC requests from blocked peers"), - ) - if err != nil { - return nil, err - } - - syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", - metric.WithUnit("1"), - metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"), - ) - if err != nil { - return nil, err - } - loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -62,22 +38,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } - loginRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.login.request.blocked.counter", - metric.WithUnit("1"), - metric.WithDescription("Number of login gRPC requests from blocked peers"), - ) - if err != nil { - return nil, err - } - - loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter", - metric.WithUnit("1"), - metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"), - ) - if err != nil { - return nil, err - } - getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of key gRPC requests from the peers to get the server's public WireGuard key"), @@ -123,19 +83,15 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro } return &GRPCMetrics{ - meter: meter, - syncRequestsCounter: syncRequestsCounter, - syncRequestsBlockedCounter: syncRequestsBlockedCounter, - syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, - loginRequestsCounter: loginRequestsCounter, - loginRequestsBlockedCounter: loginRequestsBlockedCounter, - loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, - getKeyRequestsCounter: getKeyRequestsCounter, - activeStreamsGauge: activeStreamsGauge, - syncRequestDuration: syncRequestDuration, - loginRequestDuration: loginRequestDuration, - channelQueueLength: channelQueue, - ctx: ctx, + meter: meter, + syncRequestsCounter: syncRequestsCounter, + loginRequestsCounter: loginRequestsCounter, + getKeyRequestsCounter: getKeyRequestsCounter, + activeStreamsGauge: activeStreamsGauge, + syncRequestDuration: syncRequestDuration, + loginRequestDuration: loginRequestDuration, + channelQueueLength: channelQueue, + ctx: ctx, }, err } @@ -144,11 +100,6 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequest() { grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1) } -// CountSyncRequestBlocked counts the number of gRPC sync requests from blocked peers -func (grpcMetrics *GRPCMetrics) CountSyncRequestBlocked() { - grpcMetrics.syncRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) -} - // CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() { grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1) @@ -159,25 +110,14 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() { grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1) } -// CountLoginRequestBlocked counts the number of gRPC login requests from blocked peers -func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() { - grpcMetrics.loginRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) -} - // CountLoginRequestDuration counts the duration of the login gRPC requests -func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) { +func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration) { grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) - if duration > HighLatencyThreshold { - grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) - } } // CountSyncRequestDuration counts the duration of the sync gRPC requests -func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { +func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration) { grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) - if duration > HighLatencyThreshold { - grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) - } } // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. diff --git a/management/server/types/account.go b/management/server/types/account.go index a69d3bb08..9ac2568a0 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -300,12 +300,9 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone + if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) - zones = append(zones, nbdns.CustomZone{ - Domain: peersCustomZone.Domain, - Records: records, - }) + zones = append(zones, peersCustomZone) } dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) @@ -1654,24 +1651,3 @@ func peerSupportsPortRanges(peerVer string) bool { meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) return err == nil && meetMinVer } - -// filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { - filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) - peerIPs := make(map[string]struct{}) - - // Add peer's own IP to include its own DNS records - peerIPs[peer.IP.String()] = struct{}{} - - for _, peerToConnect := range peersToConnect { - peerIPs[peerToConnect.IP.String()] = struct{}{} - } - - for _, record := range customZone.Records { - if _, exists := peerIPs[record.RData]; exists { - filteredRecords = append(filteredRecords, record) - } - } - - return filteredRecords -} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..f8ab1d627 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -2,17 +2,14 @@ package types import ( "context" - "fmt" "net" "net/netip" "slices" "testing" - "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -838,109 +835,3 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") } - -func Test_FilterZoneRecordsForPeers(t *testing.T) { - tests := []struct { - name string - peer *nbpeer.Peer - customZone nbdns.CustomZone - peersToConnect []*nbpeer.Peer - expectedRecords []nbdns.SimpleRecord - }{ - { - name: "empty peers to connect", - customZone: nbdns.CustomZone{ - Domain: "netbird.cloud.", - Records: []nbdns.SimpleRecord{ - {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, - }, - }, - peersToConnect: []*nbpeer.Peer{}, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, - expectedRecords: []nbdns.SimpleRecord{ - {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, - }, - }, - { - name: "multiple peers multiple records match", - customZone: nbdns.CustomZone{ - Domain: "netbird.cloud.", - Records: func() []nbdns.SimpleRecord { - var records []nbdns.SimpleRecord - for i := 1; i <= 100; i++ { - records = append(records, nbdns.SimpleRecord{ - Name: fmt.Sprintf("peer%d.netbird.cloud", i), - Type: int(dns.TypeA), - Class: nbdns.DefaultClass, - TTL: 300, - RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), - }) - } - return records - }(), - }, - peersToConnect: func() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { - peers = append(peers, &nbpeer.Peer{ - ID: fmt.Sprintf("peer%d", i), - IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), - }) - } - return peers - }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, - expectedRecords: func() []nbdns.SimpleRecord { - var records []nbdns.SimpleRecord - for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { - records = append(records, nbdns.SimpleRecord{ - Name: fmt.Sprintf("peer%d.netbird.cloud", i), - Type: int(dns.TypeA), - Class: nbdns.DefaultClass, - TTL: 300, - RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), - }) - } - return records - }(), - }, - { - name: "peers with multiple DNS labels", - customZone: nbdns.CustomZone{ - Domain: "netbird.cloud.", - Records: []nbdns.SimpleRecord{ - {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, - {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, - {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, - {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, - {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, - }, - }, - peersToConnect: []*nbpeer.Peer{ - {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, - {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, - }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, - expectedRecords: []nbdns.SimpleRecord{ - {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, - {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, - {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) - assert.Equal(t, len(tt.expectedRecords), len(result)) - assert.ElementsMatch(t, tt.expectedRecords, result) - }) - } -} diff --git a/management/server/types/network.go b/management/server/types/network.go index ffc019565..f072a4294 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -12,11 +12,11 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/shared/management/proto" - "github.com/netbirdio/netbird/shared/management/status" ) const ( diff --git a/management/server/types/settings.go b/management/server/types/settings.go index b4afb2f5e..56c33da3b 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -83,9 +83,6 @@ type ExtraSettings struct { // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator PeerApprovalEnabled bool - // UserApprovalRequired enables or disables the need for users joining via domain matching to be approved by an administrator - UserApprovalRequired bool - // IntegratedValidator is the string enum for the integrated validator type IntegratedValidator string // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations @@ -102,7 +99,6 @@ type ExtraSettings struct { func (e *ExtraSettings) Copy() *ExtraSettings { return &ExtraSettings{ PeerApprovalEnabled: e.PeerApprovalEnabled, - UserApprovalRequired: e.UserApprovalRequired, IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups), IntegratedValidator: e.IntegratedValidator, FlowEnabled: e.FlowEnabled, diff --git a/management/server/types/user.go b/management/server/types/user.go index beb3586df..783fe14da 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -64,7 +64,6 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` - PendingApproval bool `json:"pending_approval"` IntegrationReference integration_reference.IntegrationReference `json:"-"` } @@ -85,8 +84,6 @@ type User struct { PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` // Blocked indicates whether the user is blocked. Blocked users can't use the system. Blocked bool - // PendingApproval indicates whether the user requires approval before being activated - PendingApproval bool // LastLogin is the last time the user logged in to IdP LastLogin *time.Time // CreatedAt records the time the user was created @@ -144,17 +141,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { if userData == nil { return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, - PendingApproval: u.PendingApproval, + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, }, nil } if userData.ID != u.Id { @@ -167,17 +163,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { } return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, - PendingApproval: u.PendingApproval, + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, }, nil } @@ -199,7 +194,6 @@ func (u *User) Copy() *User { ServiceUserName: u.ServiceUserName, PATs: pats, Blocked: u.Blocked, - PendingApproval: u.PendingApproval, LastLogin: u.LastLogin, CreatedAt: u.CreatedAt, Issued: u.Issued, diff --git a/management/server/user.go b/management/server/user.go index d40d33c6a..ba1835f22 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -26,6 +26,9 @@ import ( // createServiceUser creates a new service user under the given account. func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -73,6 +76,9 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user // inviteNewUser Invites a USer to a given account and creates reference in datastore func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if am.idpManager == nil { return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } @@ -221,6 +227,9 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.Errorf(status.InvalidArgument, "self deletion is not allowed") } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return err @@ -276,6 +285,9 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if am.idpManager == nil { return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } @@ -316,6 +328,9 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin // CreatePAT creates a new PAT for the given user func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + if tokenName == "" { return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } @@ -364,6 +379,9 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -463,6 +481,9 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists) if err != nil { return nil, err @@ -519,46 +540,33 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUser = result } - var globalErr error - for _, update := range updates { - if update == nil { - return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + for _, update := range updates { + if update == nil { + return status.Errorf(status.InvalidArgument, "provided user update is nil") + } - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } - - if userHadPeers { - updateAccountPeers = true - } - - err = transaction.SaveUser(ctx, updatedUser) - if err != nil { - return fmt.Errorf("failed to save updated user %s: %w", update.Id, err) - } - usersToSave = append(usersToSave, updatedUser) addUserEvents = append(addUserEvents, userEvents...) peersToExpire = append(peersToExpire, userPeersToExpire...) - return nil - }) - if err != nil { - log.WithContext(ctx).Errorf("failed to save user %s: %s", update.Id, err) - if len(updates) == 1 { - return nil, err + if userHadPeers { + updateAccountPeers = true } - globalErr = errors.Join(globalErr, err) - // continue when updating multiple users } + return transaction.SaveUsers(ctx, usersToSave) + }) + if err != nil { + return nil, err } - var updatedUsersInfo = make([]*types.UserInfo, 0, len(usersToSave)) + var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates)) userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID) if err != nil { @@ -591,7 +599,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, am.UpdateAccountPeers(ctx, accountID) } - return updatedUsersInfo, globalErr + return updatedUsersInfo, nil } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. @@ -656,7 +664,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } transferredOwnerRole = result - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, updatedUser.AccountID, update.Id) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id) if err != nil { return false, nil, nil, nil, err } @@ -942,11 +950,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) - if peer.UserID == "" { - // we do not want to expire peers that are added via setup key - continue - } - if peer.Status.LoginExpired { continue } @@ -965,7 +968,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service - log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.BufferUpdateAccountPeers(ctx, accountID) } @@ -1213,77 +1215,3 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut return userWithPermissions, nil } - -// ApproveUser approves a user that is pending approval -func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) - if err != nil { - return nil, err - } - - if user.AccountID != accountID { - return nil, status.NewUserNotFoundError(targetUserID) - } - - if !user.PendingApproval { - return nil, status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) - } - - user.Blocked = false - user.PendingApproval = false - - err = am.Store.SaveUser(ctx, user) - if err != nil { - return nil, err - } - - am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserApproved, nil) - - userInfo, err := am.getUserInfo(ctx, user, accountID) - if err != nil { - return nil, err - } - - return userInfo, nil -} - -// RejectUser rejects a user that is pending approval by deleting them -func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) - if err != nil { - return err - } - - if user.AccountID != accountID { - return status.NewUserNotFoundError(targetUserID) - } - - if !user.PendingApproval { - return status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) - } - - err = am.DeleteUser(ctx, accountID, initiatorUserID, targetUserID) - if err != nil { - return err - } - - am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserRejected, nil) - - return nil -} diff --git a/management/server/user_test.go b/management/server/user_test.go index 9638559f9..8ab0c1565 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1746,117 +1746,3 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { return permissions } - -func TestApproveUser(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create account with admin and pending approval user - account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Create admin user - adminUser := types.NewAdminUser("admin-user") - adminUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), adminUser) - require.NoError(t, err) - - // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") - pendingUser.AccountID = account.Id - pendingUser.Blocked = true - pendingUser.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser) - require.NoError(t, err) - - // Test successful approval - approvedUser, err := manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) - require.NoError(t, err) - assert.False(t, approvedUser.IsBlocked) - assert.False(t, approvedUser.PendingApproval) - - // Verify user is updated in store - updatedUser, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) - require.NoError(t, err) - assert.False(t, updatedUser.Blocked) - assert.False(t, updatedUser.PendingApproval) - - // Test approval of non-pending user should fail - _, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) - require.Error(t, err) - assert.Contains(t, err.Error(), "not pending approval") - - // Test approval by non-admin should fail - regularUser := types.NewRegularUser("regular-user") - regularUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), regularUser) - require.NoError(t, err) - - pendingUser2 := types.NewRegularUser("pending-user-2") - pendingUser2.AccountID = account.Id - pendingUser2.Blocked = true - pendingUser2.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser2) - require.NoError(t, err) - - _, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) - require.Error(t, err) -} - -func TestRejectUser(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - } - - // Create account with admin and pending approval user - account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) - err = manager.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - // Create admin user - adminUser := types.NewAdminUser("admin-user") - adminUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), adminUser) - require.NoError(t, err) - - // Create user pending approval - pendingUser := types.NewRegularUser("pending-user") - pendingUser.AccountID = account.Id - pendingUser.Blocked = true - pendingUser.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser) - require.NoError(t, err) - - // Test successful rejection - err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) - require.NoError(t, err) - - // Verify user is deleted from store - _, err = manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) - require.Error(t, err) - - // Test rejection of non-pending user should fail - regularUser := types.NewRegularUser("regular-user") - regularUser.AccountID = account.Id - err = manager.Store.SaveUser(context.Background(), regularUser) - require.NoError(t, err) - - err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id) - require.Error(t, err) - assert.Contains(t, err.Error(), "not pending approval") - - // Test rejection by non-admin should fail - pendingUser2 := types.NewRegularUser("pending-user-2") - pendingUser2.AccountID = account.Id - pendingUser2.Blocked = true - pendingUser2.PendingApproval = true - err = manager.Store.SaveUser(context.Background(), pendingUser2) - require.NoError(t, err) - - err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) - require.Error(t, err) -} diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 12219e29b..332127660 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -73,12 +73,7 @@ func (l *Listener) Shutdown(ctx context.Context) error { func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { connRemoteAddr := remoteAddr(r) - - acceptOptions := &websocket.AcceptOptions{ - OriginPatterns: []string{"*"}, - } - - wsConn, err := websocket.Accept(w, r, acceptOptions) + wsConn, err := websocket.Accept(w, r, nil) if err != nil { log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err) return diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go index 4dfea6da1..6b8a6f701 100644 --- a/relay/test/benchmark_test.go +++ b/relay/test/benchmark_test.go @@ -13,11 +13,10 @@ import ( "github.com/pion/logging" "github.com/pion/turn/v3" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay/auth/allow" "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client" + "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/util" ) @@ -101,7 +100,7 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i), iface.DefaultMTU) + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -111,7 +110,7 @@ func transfer(t *testing.T, testData []byte, peerPairs int) { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i), iface.DefaultMTU) + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) err := c.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go index e6924061f..aa0fc662a 100644 --- a/relay/testec2/relay.go +++ b/relay/testec2/relay.go @@ -11,7 +11,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client" ) @@ -71,7 +70,7 @@ func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { ctx := context.Background() clientsSender := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsSender); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i), iface.DefaultMTU) + c := client.NewClient(serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) if err := c.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -157,7 +156,7 @@ func runReader(conn net.Conn) time.Duration { func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { clientsReceiver := make([]*client.Client, peerPairs) for i := 0; i < cap(clientsReceiver); i++ { - c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i), iface.DefaultMTU) + c := client.NewClient(serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) err := c.Connect(context.Background()) if err != nil { log.Fatalf("failed to connect to server: %s", err) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..856d332cb 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -130,6 +130,36 @@ repo_gpgcheck=1 EOF } +install_aur_package() { + INSTALL_PKGS="git base-devel go" + REMOVE_PKGS="" + + # Check if dependencies are installed + for PKG in $INSTALL_PKGS; do + if ! pacman -Q "$PKG" > /dev/null 2>&1; then + # Install missing package(s) + ${SUDO} pacman -S "$PKG" --noconfirm + + # Add installed package for clean up later + REMOVE_PKGS="$REMOVE_PKGS $PKG" + fi + done + + # Build package from AUR + cd /tmp && git clone https://aur.archlinux.org/netbird.git + cd netbird && makepkg -sri --noconfirm + + if ! $SKIP_UI_APP; then + cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git + cd netbird-ui && makepkg -sri --noconfirm + fi + + if [ -n "$REMOVE_PKGS" ]; then + # Clean up the installed packages + ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm + fi +} + prepare_tun_module() { # Create the necessary file structure for /dev/net/tun if [ ! -c /dev/net/tun ]; then @@ -246,9 +276,12 @@ install_netbird() { if ! $SKIP_UI_APP; then ${SUDO} rpm-ostree -y install netbird-ui fi - # ensure the service is started after install - ${SUDO} netbird service install || true - ${SUDO} netbird service start || true + ;; + pacman) + ${SUDO} pacman -Syy + install_aur_package + # in-line with the docs at https://wiki.archlinux.org/title/Netbird + ${SUDO} systemctl enable --now netbird@main.service ;; pkg) # Check if the package is already installed @@ -425,7 +458,11 @@ if type uname >/dev/null 2>&1; then elif [ -x "$(command -v yum)" ]; then PACKAGE_MANAGER="yum" echo "The installation will be performed using yum package manager" + elif [ -x "$(command -v pacman)" ]; then + PACKAGE_MANAGER="pacman" + echo "The installation will be performed using pacman package manager" fi + else echo "Unable to determine OS type from /etc/os-release" exit 1 diff --git a/route/route.go b/route/route.go index 08a2d37dc..604f8c60f 100644 --- a/route/route.go +++ b/route/route.go @@ -107,8 +107,6 @@ type Route struct { Enabled bool Groups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"` - // SkipAutoApply indicates if this exit node route (0.0.0.0/0) should skip auto-application for client routing - SkipAutoApply bool } // EventMeta returns activity event meta related to the route @@ -138,7 +136,6 @@ func (r *Route) Copy() *Route { Enabled: r.Enabled, Groups: slices.Clone(r.Groups), AccessControlGroups: slices.Clone(r.AccessControlGroups), - SkipAutoApply: r.SkipAutoApply, } return route } @@ -165,8 +162,7 @@ func (r *Route) Equal(other *Route) bool { other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && slices.Equal(r.PeerGroups, other.PeerGroups) && - slices.Equal(r.AccessControlGroups, other.AccessControlGroups) && - other.SkipAutoApply == r.SkipAutoApply + slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } // IsDynamic returns if the route is dynamic, i.e. has domains diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 5736a16e1..e38ce9b2f 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -9,30 +9,34 @@ import ( "time" "github.com/golang/mock/gomock" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/server/config" - mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/management-integrations/integrations" + + "github.com/netbirdio/netbird/encryption" + mgmt "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/mock_server" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/util" ) @@ -69,31 +73,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) jobManager := mgmt.NewJobManager(nil, store) eventStore := &activity.InMemoryEventStore{} - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - permissionsManagerMock := permissions.NewMockManager(ctrl) - permissionsManagerMock. - EXPECT(). - ValidateUserPermissions( - gomock.Any(), - gomock.Any(), - gomock.Any(), - gomock.Any(), - gomock.Any(), - ). - Return(true, nil). - AnyTimes() - - peersManger := peers.NewManager(store, permissionsManagerMock) - settingsManagerMock := settings.NewMockManager(ctrl) - - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager. EXPECT(). @@ -124,7 +110,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { AnyTimes() accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) - if err != nil { t.Fatal(err) } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index b3fd28e9c..f5759ef21 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -18,11 +18,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" - nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -53,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/management/client/rest/client_test.go b/shared/management/client/rest/client_test.go index 54a0290d0..56c859652 100644 --- a/shared/management/client/rest/client_test.go +++ b/shared/management/client/rest/client_test.go @@ -8,8 +8,8 @@ import ( "net/http/httptest" "testing" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" ) func withMockClient(callback func(*rest.Client, *http.ServeMux)) { @@ -26,7 +26,7 @@ func ptr[T any, PT *T](x T) PT { func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) { t.Helper() - handler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) + handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) server := httptest.NewServer(handler) defer server.Close() c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp") diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 942156ad2..86082e606 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -278,10 +278,6 @@ components: description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. type: boolean example: true - user_approval_required: - description: Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. - type: boolean - example: false network_traffic_logs_enabled: description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored. type: boolean @@ -298,7 +294,6 @@ components: example: true required: - peer_approval_enabled - - user_approval_required - network_traffic_logs_enabled - network_traffic_logs_groups - network_traffic_packet_counter_enabled @@ -360,10 +355,6 @@ components: description: Is true if this user is blocked. Blocked users can't use the system type: boolean example: false - pending_approval: - description: Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. - type: boolean - example: false issued: description: How user was issued by API or Integration type: string @@ -378,7 +369,6 @@ components: - auto_groups - status - is_blocked - - pending_approval UserPermissions: type: object properties: @@ -1472,10 +1462,6 @@ components: items: type: string example: "chacbco6lnnbn6cg5s91" - skip_auto_apply: - description: Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing - type: boolean - example: false required: - id - description @@ -2778,63 +2764,6 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" - /api/users/{userId}/approve: - post: - summary: Approve user - description: Approve a user that is pending approval - tags: [ Users ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - parameters: - - in: path - name: userId - required: true - schema: - type: string - description: The unique identifier of a user - responses: - '200': - description: Returns the approved user - content: - application/json: - schema: - "$ref": "#/components/schemas/User" - '400': - "$ref": "#/components/responses/bad_request" - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" - /api/users/{userId}/reject: - delete: - summary: Reject user - description: Reject a user that is pending approval by removing them from the account - tags: [ Users ] - security: - - BearerAuth: [ ] - - TokenAuth: [ ] - parameters: - - in: path - name: userId - required: true - schema: - type: string - description: The unique identifier of a user - responses: - '200': - description: User rejected successfully - content: {} - '400': - "$ref": "#/components/responses/bad_request" - '401': - "$ref": "#/components/responses/requires_authentication" - '403': - "$ref": "#/components/responses/forbidden" - '500': - "$ref": "#/components/responses/internal_error" /api/users/current: get: summary: Retrieve current user diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 883ce4928..09c1ecb0f 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -284,9 +284,6 @@ type AccountExtraSettings struct { // PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. PeerApprovalEnabled bool `json:"peer_approval_enabled"` - - // UserApprovalRequired Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. - UserApprovalRequired bool `json:"user_approval_required"` } // AccountOnboarding defines model for AccountOnboarding. @@ -1622,9 +1619,6 @@ type Route struct { // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` PeerGroups *[]string `json:"peer_groups,omitempty"` - - // SkipAutoApply Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing - SkipAutoApply *bool `json:"skip_auto_apply,omitempty"` } // RouteRequest defines model for RouteRequest. @@ -1664,9 +1658,6 @@ type RouteRequest struct { // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` PeerGroups *[]string `json:"peer_groups,omitempty"` - - // SkipAutoApply Indicate if this exit node route (0.0.0.0/0) should skip auto-application for client routing - SkipAutoApply *bool `json:"skip_auto_apply,omitempty"` } // RulePortRange Policy rule affected ports range @@ -1855,11 +1846,8 @@ type User struct { LastLogin *time.Time `json:"last_login,omitempty"` // Name User's name from idp provider - Name string `json:"name"` - - // PendingApproval Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. - PendingApproval bool `json:"pending_approval"` - Permissions *UserPermissions `json:"permissions,omitempty"` + Name string `json:"name"` + Permissions *UserPermissions `json:"permissions,omitempty"` // Role User's NetBird account role Role string `json:"role"` diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 97e06f6a1..5ee9565df 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v3.21.12 // source: management.proto package proto @@ -1982,7 +1982,6 @@ type PeerConfig struct { Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` LazyConnectionEnabled bool `protobuf:"varint,6,opt,name=LazyConnectionEnabled,proto3" json:"LazyConnectionEnabled,omitempty"` - Mtu int32 `protobuf:"varint,7,opt,name=mtu,proto3" json:"mtu,omitempty"` } func (x *PeerConfig) Reset() { @@ -2059,13 +2058,6 @@ func (x *PeerConfig) GetLazyConnectionEnabled() bool { return false } -func (x *PeerConfig) GetMtu() int32 { - if x != nil { - return x.Mtu - } - return 0 -} - // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -2701,16 +2693,15 @@ type Route struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=Network,proto3" json:"Network,omitempty"` - NetworkType int64 `protobuf:"varint,3,opt,name=NetworkType,proto3" json:"NetworkType,omitempty"` - Peer string `protobuf:"bytes,4,opt,name=Peer,proto3" json:"Peer,omitempty"` - Metric int64 `protobuf:"varint,5,opt,name=Metric,proto3" json:"Metric,omitempty"` - Masquerade bool `protobuf:"varint,6,opt,name=Masquerade,proto3" json:"Masquerade,omitempty"` - NetID string `protobuf:"bytes,7,opt,name=NetID,proto3" json:"NetID,omitempty"` - Domains []string `protobuf:"bytes,8,rep,name=Domains,proto3" json:"Domains,omitempty"` - KeepRoute bool `protobuf:"varint,9,opt,name=keepRoute,proto3" json:"keepRoute,omitempty"` - SkipAutoApply bool `protobuf:"varint,10,opt,name=skipAutoApply,proto3" json:"skipAutoApply,omitempty"` + ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` + Network string `protobuf:"bytes,2,opt,name=Network,proto3" json:"Network,omitempty"` + NetworkType int64 `protobuf:"varint,3,opt,name=NetworkType,proto3" json:"NetworkType,omitempty"` + Peer string `protobuf:"bytes,4,opt,name=Peer,proto3" json:"Peer,omitempty"` + Metric int64 `protobuf:"varint,5,opt,name=Metric,proto3" json:"Metric,omitempty"` + Masquerade bool `protobuf:"varint,6,opt,name=Masquerade,proto3" json:"Masquerade,omitempty"` + NetID string `protobuf:"bytes,7,opt,name=NetID,proto3" json:"NetID,omitempty"` + Domains []string `protobuf:"bytes,8,rep,name=Domains,proto3" json:"Domains,omitempty"` + KeepRoute bool `protobuf:"varint,9,opt,name=keepRoute,proto3" json:"keepRoute,omitempty"` } func (x *Route) Reset() { @@ -2808,13 +2799,6 @@ func (x *Route) GetKeepRoute() bool { return false } -func (x *Route) GetSkipAutoApply() bool { - if x != nil { - return x.SkipAutoApply - } - return false -} - // DNSConfig represents a dns.Update type DNSConfig struct { state protoimpl.MessageState diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index ae5a1b29d..4aaefb5aa 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -303,8 +303,6 @@ message PeerConfig { bool RoutingPeerDnsResolutionEnabled = 5; bool LazyConnectionEnabled = 6; - - int32 mtu = 7; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections @@ -441,7 +439,6 @@ message Route { string NetID = 7; repeated string Domains = 8; bool keepRoute = 9; - bool skipAutoApply = 10; } // DNSConfig represents a dns.Update diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 1e914babb..7660174d6 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -42,10 +42,7 @@ const ( // Type is a type of the Error type Type int32 -var ( - ErrExtraSettingsNotFound = errors.New("extra settings not found") - ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in") -) +var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found") // Error is an internal error type Error struct { @@ -113,11 +110,6 @@ func NewUserBlockedError() error { return Errorf(PermissionDenied, "user is blocked") } -// NewUserPendingApprovalError creates a new Error with PermissionDenied type for a blocked user pending approval -func NewUserPendingApprovalError() error { - return Errorf(PermissionDenied, "user is pending approval") -} - // NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer func NewPeerNotRegisteredError() error { return Errorf(Unauthenticated, "peer is not registered") diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index 5dabc5742..37c9debc2 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -9,7 +9,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client/dialer" "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" @@ -144,12 +143,10 @@ type Client struct { listenerMutex sync.Mutex stateSubscription *PeersStateSubscription - - mtu uint16 } // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect -func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, mtu uint16) *Client { +func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { hashedID := messages.HashID(peerID) relayLog := log.WithFields(log.Fields{"relay": serverURL}) @@ -158,7 +155,6 @@ func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string, connectionURL: serverURL, authTokenStore: authTokenStore, hashedID: hashedID, - mtu: mtu, bufPool: &sync.Pool{ New: func() any { buf := make([]byte, bufferSize) @@ -296,16 +292,7 @@ func (c *Client) Close() error { } func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { - // Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues - var dialers []dialer.DialeFn - if c.mtu > 0 && c.mtu > iface.DefaultMTU { - c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) - dialers = []dialer.DialeFn{ws.Dialer{}} - } else { - dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} - } - - rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) + rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{}) conn, err := rd.Dial() if err != nil { return nil, err diff --git a/shared/relay/client/client_test.go b/shared/relay/client/client_test.go index 8fe5f04f4..c7c5fbf2b 100644 --- a/shared/relay/client/client_test.go +++ b/shared/relay/client/client_test.go @@ -10,7 +10,6 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/shared/relay/auth/allow" "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/util" @@ -64,7 +63,7 @@ func TestClient(t *testing.T) { t.Fatalf("failed to start server: %s", err) } t.Log("alice connecting to server") - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -72,7 +71,7 @@ func TestClient(t *testing.T) { defer clientAlice.Close() t.Log("placeholder connecting to server") - clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder", iface.DefaultMTU) + clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder") err = clientPlaceHolder.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -80,7 +79,7 @@ func TestClient(t *testing.T) { defer clientPlaceHolder.Close() t.Log("Bob connecting to server") - clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + clientBob := NewClient(serverURL, hmacTokenStore, "bob") err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -138,7 +137,7 @@ func TestRegistration(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err != nil { _ = srv.Shutdown(ctx) @@ -178,7 +177,7 @@ func TestRegistrationTimeout(t *testing.T) { _ = fakeTCPListener.Close() }(fakeTCPListener) - clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err == nil { t.Errorf("failed to connect to server: %s", err) @@ -219,7 +218,7 @@ func TestEcho(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -231,7 +230,7 @@ func TestEcho(t *testing.T) { } }() - clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU) + clientBob := NewClient(serverURL, hmacTokenStore, idBob) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -309,7 +308,7 @@ func TestBindToUnavailabePeer(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -355,13 +354,13 @@ func TestBindReconnect(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + clientBob := NewClient(serverURL, hmacTokenStore, "bob") err = clientBob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -383,7 +382,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to close client: %s", err) } - clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice = NewClient(serverURL, hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -456,13 +455,13 @@ func TestCloseConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + bob := NewClient(serverURL, hmacTokenStore, "bob") err = bob.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err != nil { t.Errorf("failed to connect to server: %s", err) @@ -518,13 +517,13 @@ func TestCloseRelayConn(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - bob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU) + bob := NewClient(serverURL, hmacTokenStore, "bob") err = bob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, "alice") err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -572,7 +571,7 @@ func TestCloseByServer(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + relayClient := NewClient(serverURL, hmacTokenStore, idAlice) if err = relayClient.Connect(ctx); err != nil { log.Fatalf("failed to connect to server: %s", err) } @@ -628,7 +627,7 @@ func TestCloseByClient(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - relayClient := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + relayClient := NewClient(serverURL, hmacTokenStore, idAlice) err = relayClient.Connect(ctx) if err != nil { log.Fatalf("failed to connect to server: %s", err) @@ -679,7 +678,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { t.Fatalf("failed to start server: %s", err) } - clientAlice := NewClient(serverURL, hmacTokenStore, idAlice, iface.DefaultMTU) + clientAlice := NewClient(serverURL, hmacTokenStore, idAlice) err = clientAlice.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) @@ -691,7 +690,7 @@ func TestCloseNotDrainedChannel(t *testing.T) { } }() - clientBob := NewClient(serverURL, hmacTokenStore, idBob, iface.DefaultMTU) + clientBob := NewClient(serverURL, hmacTokenStore, idBob) err = clientBob.Connect(ctx) if err != nil { t.Fatalf("failed to connect to server: %s", err) diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 967e18d79..b496f6a9b 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" quictls "github.com/netbirdio/netbird/shared/relay/tls" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index ef6bd6b3c..109651f5d 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) type Dialer struct { diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index 6220e7f6b..f3428f255 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -63,25 +63,20 @@ type Manager struct { onDisconnectedListeners map[string]*list.List onReconnectedListenerFn func() listenerLock sync.Mutex - - mtu uint16 } // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. -func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager { +func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { tokenStore := &relayAuth.TokenStore{} m := &Manager{ ctx: ctx, peerID: peerID, tokenStore: tokenStore, - mtu: mtu, serverPicker: &ServerPicker{ - TokenStore: tokenStore, - PeerID: peerID, - MTU: mtu, - ConnectionTimeout: defaultConnectionTimeout, + TokenStore: tokenStore, + PeerID: peerID, }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), @@ -258,7 +253,7 @@ func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string m.relayClients[serverAddress] = rt m.relayClientsMutex.Unlock() - relayClient := NewClient(serverAddress, m.tokenStore, m.peerID, m.mtu) + relayClient := NewClient(serverAddress, m.tokenStore, m.peerID) err := relayClient.Connect(m.ctx) if err != nil { rt.err = err diff --git a/shared/relay/client/manager_test.go b/shared/relay/client/manager_test.go index f00b35707..674555ff4 100644 --- a/shared/relay/client/manager_test.go +++ b/shared/relay/client/manager_test.go @@ -8,15 +8,14 @@ import ( log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/shared/relay/auth/allow" + "github.com/netbirdio/netbird/relay/server" ) func TestEmptyURL(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mgr := NewManager(ctx, nil, "alice", iface.DefaultMTU) + mgr := NewManager(ctx, nil, "alice") err := mgr.Serve() if err == nil { t.Errorf("expected error, got nil") @@ -91,12 +90,12 @@ func TestForeignConn(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice") if err := clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU) + clientBob := NewManager(mCtx, toURL(srvCfg2), "bob") if err := clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } @@ -198,12 +197,12 @@ func TestForeginConnClose(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob", iface.DefaultMTU) + mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob") if err := mgrBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - mgr := NewManager(mCtx, toURL(srvCfg1), "alice", iface.DefaultMTU) + mgr := NewManager(mCtx, toURL(srvCfg1), "alice") err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -283,7 +282,7 @@ func TestForeignAutoClose(t *testing.T) { t.Log("connect to server 1.") mCtx, cancel := context.WithCancel(ctx) defer cancel() - mgr := NewManager(mCtx, toURL(srvCfg1), idAlice, iface.DefaultMTU) + mgr := NewManager(mCtx, toURL(srvCfg1), idAlice) err = mgr.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -354,13 +353,13 @@ func TestAutoReconnect(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientBob := NewManager(mCtx, toURL(srvCfg), "bob", iface.DefaultMTU) + clientBob := NewManager(mCtx, toURL(srvCfg), "bob") err = clientBob.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") err = clientAlice.Serve() if err != nil { t.Fatalf("failed to serve manager: %s", err) @@ -429,12 +428,12 @@ func TestNotifierDoubleAdd(t *testing.T) { mCtx, cancel := context.WithCancel(ctx) defer cancel() - clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob", iface.DefaultMTU) + clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob") if err = clientBob.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } - clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice", iface.DefaultMTU) + clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice") if err = clientAlice.Serve(); err != nil { t.Fatalf("failed to serve manager: %s", err) } diff --git a/shared/relay/client/picker.go b/shared/relay/client/picker.go index 39d0ba072..1cad466ba 100644 --- a/shared/relay/client/picker.go +++ b/shared/relay/client/picker.go @@ -13,8 +13,11 @@ import ( ) const ( - maxConcurrentServers = 7 - defaultConnectionTimeout = 30 * time.Second + maxConcurrentServers = 7 +) + +var ( + connectionTimeout = 30 * time.Second ) type connResult struct { @@ -24,15 +27,13 @@ type connResult struct { } type ServerPicker struct { - TokenStore *auth.TokenStore - ServerURLs atomic.Value - PeerID string - MTU uint16 - ConnectionTimeout time.Duration + TokenStore *auth.TokenStore + ServerURLs atomic.Value + PeerID string } func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { - ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout) + ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) defer cancel() totalServers := len(sp.ServerURLs.Load().([]string)) @@ -69,7 +70,7 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { log.Infof("try to connecting to relay server: %s", url) - relayClient := NewClient(url, sp.TokenStore, sp.PeerID, sp.MTU) + relayClient := NewClient(url, sp.TokenStore, sp.PeerID) err := relayClient.Connect(ctx) resultChan <- connResult{ RelayClient: relayClient, diff --git a/shared/relay/client/picker_test.go b/shared/relay/client/picker_test.go index fb3fa7375..28167c5ce 100644 --- a/shared/relay/client/picker_test.go +++ b/shared/relay/client/picker_test.go @@ -8,15 +8,15 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { - timeout := 5 * time.Second + connectionTimeout = 5 * time.Second + sp := ServerPicker{ - TokenStore: nil, - PeerID: "test", - ConnectionTimeout: timeout, + TokenStore: nil, + PeerID: "test", } sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) - ctx, cancel := context.WithTimeout(context.Background(), timeout+1) + ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() { diff --git a/shared/relay/healthcheck/env.go b/shared/relay/healthcheck/env.go deleted file mode 100644 index 2b584c195..000000000 --- a/shared/relay/healthcheck/env.go +++ /dev/null @@ -1,24 +0,0 @@ -package healthcheck - -import ( - "os" - "strconv" - - log "github.com/sirupsen/logrus" -) - -const ( - defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" -) - -func getAttemptThresholdFromEnv() int { - if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { - threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) - if err != nil { - log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) - return defaultAttemptThreshold - } - return int(threshold) - } - return defaultAttemptThreshold -} diff --git a/shared/relay/healthcheck/env_test.go b/shared/relay/healthcheck/env_test.go deleted file mode 100644 index 2e14bb8bf..000000000 --- a/shared/relay/healthcheck/env_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package healthcheck - -import ( - "os" - "testing" -) - -//nolint:tenv -func TestGetAttemptThresholdFromEnv(t *testing.T) { - tests := []struct { - name string - envValue string - expected int - }{ - {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, - {"Custom attempt threshold when env is set to a valid integer", "3", 3}, - {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.envValue == "" { - os.Unsetenv(defaultAttemptThresholdEnv) - } else { - os.Setenv(defaultAttemptThresholdEnv, tt.envValue) - } - - result := getAttemptThresholdFromEnv() - if result != tt.expected { - t.Fatalf("Expected %d, got %d", tt.expected, result) - } - - os.Unsetenv(defaultAttemptThresholdEnv) - }) - } -} diff --git a/shared/relay/healthcheck/receiver.go b/shared/relay/healthcheck/receiver.go index 90f795bbe..b3503d5db 100644 --- a/shared/relay/healthcheck/receiver.go +++ b/shared/relay/healthcheck/receiver.go @@ -7,15 +7,10 @@ import ( log "github.com/sirupsen/logrus" ) -const ( - defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second +var ( + heartbeatTimeout = healthCheckInterval + 10*time.Second ) -type ReceiverOptions struct { - HeartbeatTimeout time.Duration - AttemptThreshold int -} - // Receiver is a healthcheck receiver // It will listen for heartbeat and check if the heartbeat is not received in a certain time // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work @@ -32,23 +27,6 @@ type Receiver struct { // NewReceiver creates a new healthcheck receiver and start the timer in the background func NewReceiver(log *log.Entry) *Receiver { - opts := ReceiverOptions{ - HeartbeatTimeout: defaultHeartbeatTimeout, - AttemptThreshold: getAttemptThresholdFromEnv(), - } - return NewReceiverWithOpts(log, opts) -} - -func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver { - heartbeatTimeout := opts.HeartbeatTimeout - if heartbeatTimeout <= 0 { - heartbeatTimeout = defaultHeartbeatTimeout - } - attemptThreshold := opts.AttemptThreshold - if attemptThreshold <= 0 { - attemptThreshold = defaultAttemptThreshold - } - ctx, ctxCancel := context.WithCancel(context.Background()) r := &Receiver{ @@ -57,10 +35,10 @@ func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver { ctx: ctx, ctxCancel: ctxCancel, heartbeat: make(chan struct{}, 1), - attemptThreshold: attemptThreshold, + attemptThreshold: getAttemptThresholdFromEnv(), } - go r.waitForHealthcheck(heartbeatTimeout) + go r.waitForHealthcheck() return r } @@ -77,7 +55,7 @@ func (r *Receiver) Stop() { r.ctxCancel() } -func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) { +func (r *Receiver) waitForHealthcheck() { ticker := time.NewTicker(heartbeatTimeout) defer ticker.Stop() defer r.ctxCancel() diff --git a/shared/relay/healthcheck/receiver_test.go b/shared/relay/healthcheck/receiver_test.go index b20cc5124..2794159f6 100644 --- a/shared/relay/healthcheck/receiver_test.go +++ b/shared/relay/healthcheck/receiver_test.go @@ -2,18 +2,31 @@ package healthcheck import ( "context" + "fmt" + "os" + "sync" "testing" "time" log "github.com/sirupsen/logrus" ) -func TestNewReceiver(t *testing.T) { +// Mutex to protect global variable access in tests +var testMutex sync.Mutex - opts := ReceiverOptions{ - HeartbeatTimeout: 5 * time.Second, - } - r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) +func TestNewReceiver(t *testing.T) { + testMutex.Lock() + originalTimeout := heartbeatTimeout + heartbeatTimeout = 5 * time.Second + testMutex.Unlock() + + defer func() { + testMutex.Lock() + heartbeatTimeout = originalTimeout + testMutex.Unlock() + }() + + r := NewReceiver(log.WithContext(context.Background())) defer r.Stop() select { @@ -25,10 +38,18 @@ func TestNewReceiver(t *testing.T) { } func TestNewReceiverNotReceive(t *testing.T) { - opts := ReceiverOptions{ - HeartbeatTimeout: 1 * time.Second, - } - r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) + testMutex.Lock() + originalTimeout := heartbeatTimeout + heartbeatTimeout = 1 * time.Second + testMutex.Unlock() + + defer func() { + testMutex.Lock() + heartbeatTimeout = originalTimeout + testMutex.Unlock() + }() + + r := NewReceiver(log.WithContext(context.Background())) defer r.Stop() select { @@ -40,10 +61,18 @@ func TestNewReceiverNotReceive(t *testing.T) { } func TestNewReceiverAck(t *testing.T) { - opts := ReceiverOptions{ - HeartbeatTimeout: 2 * time.Second, - } - r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) + testMutex.Lock() + originalTimeout := heartbeatTimeout + heartbeatTimeout = 2 * time.Second + testMutex.Unlock() + + defer func() { + testMutex.Lock() + heartbeatTimeout = originalTimeout + testMutex.Unlock() + }() + + r := NewReceiver(log.WithContext(context.Background())) defer r.Stop() r.Heartbeat() @@ -68,19 +97,30 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - healthCheckInterval := 1 * time.Second + testMutex.Lock() + originalInterval := healthCheckInterval + originalTimeout := heartbeatTimeout + healthCheckInterval = 1 * time.Second + heartbeatTimeout = healthCheckInterval + 500*time.Millisecond + testMutex.Unlock() - opts := ReceiverOptions{ - HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond, - AttemptThreshold: tc.threshold, - } + defer func() { + testMutex.Lock() + healthCheckInterval = originalInterval + heartbeatTimeout = originalTimeout + testMutex.Unlock() + }() + //nolint:tenv + os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) + defer os.Unsetenv(defaultAttemptThresholdEnv) - receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts) + receiver := NewReceiver(log.WithField("test_name", tc.name)) - testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval if tc.resetCounterOnce { receiver.Heartbeat() + t.Logf("reset counter once") } select { @@ -94,6 +134,7 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { } t.Fatalf("should have timed out before %s", testTimeout) } + }) } } diff --git a/shared/relay/healthcheck/sender.go b/shared/relay/healthcheck/sender.go index 771e94206..57b3015ec 100644 --- a/shared/relay/healthcheck/sender.go +++ b/shared/relay/healthcheck/sender.go @@ -2,76 +2,52 @@ package healthcheck import ( "context" + "os" + "strconv" "time" log "github.com/sirupsen/logrus" ) const ( - defaultAttemptThreshold = 1 - - defaultHealthCheckInterval = 25 * time.Second - defaultHealthCheckTimeout = 20 * time.Second + defaultAttemptThreshold = 1 + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" ) -type SenderOptions struct { - HealthCheckInterval time.Duration - HealthCheckTimeout time.Duration - AttemptThreshold int -} +var ( + healthCheckInterval = 25 * time.Second + healthCheckTimeout = 20 * time.Second +) // Sender is a healthcheck sender // It will send healthcheck signal to the receiver // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // It will also stop if the context is canceled type Sender struct { + log *log.Entry // HealthCheck is a channel to send health check signal to the peer HealthCheck chan struct{} // Timeout is a channel to the health check signal is not received in a certain time Timeout chan struct{} - log *log.Entry - healthCheckInterval time.Duration - timeout time.Duration - ack chan struct{} alive bool attemptThreshold int } -func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender { - if opts.HealthCheckInterval <= 0 { - opts.HealthCheckInterval = defaultHealthCheckInterval - } - if opts.HealthCheckTimeout <= 0 { - opts.HealthCheckTimeout = defaultHealthCheckTimeout - } - if opts.AttemptThreshold <= 0 { - opts.AttemptThreshold = defaultAttemptThreshold - } +// NewSender creates a new healthcheck sender +func NewSender(log *log.Entry) *Sender { hc := &Sender{ - HealthCheck: make(chan struct{}, 1), - Timeout: make(chan struct{}, 1), - log: log, - healthCheckInterval: opts.HealthCheckInterval, - timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout, - ack: make(chan struct{}, 1), - attemptThreshold: opts.AttemptThreshold, + log: log, + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + ack: make(chan struct{}, 1), + attemptThreshold: getAttemptThresholdFromEnv(), } return hc } -// NewSender creates a new healthcheck sender -func NewSender(log *log.Entry) *Sender { - opts := SenderOptions{ - HealthCheckInterval: defaultHealthCheckInterval, - HealthCheckTimeout: defaultHealthCheckTimeout, - AttemptThreshold: getAttemptThresholdFromEnv(), - } - return NewSenderWithOpts(log, opts) -} - // OnHCResponse sends an acknowledgment signal to the sender func (hc *Sender) OnHCResponse() { select { @@ -81,10 +57,10 @@ func (hc *Sender) OnHCResponse() { } func (hc *Sender) StartHealthCheck(ctx context.Context) { - ticker := time.NewTicker(hc.healthCheckInterval) + ticker := time.NewTicker(healthCheckInterval) defer ticker.Stop() - timeoutTicker := time.NewTicker(hc.timeout) + timeoutTicker := time.NewTicker(hc.getTimeoutTime()) defer timeoutTicker.Stop() defer close(hc.HealthCheck) @@ -116,3 +92,19 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) { } } } + +func (hc *Sender) getTimeoutTime() time.Duration { + return healthCheckInterval + healthCheckTimeout +} + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/shared/relay/healthcheck/sender_test.go b/shared/relay/healthcheck/sender_test.go index 122fe0f16..23446366a 100644 --- a/shared/relay/healthcheck/sender_test.go +++ b/shared/relay/healthcheck/sender_test.go @@ -2,23 +2,26 @@ package healthcheck import ( "context" + "fmt" + "os" "testing" "time" log "github.com/sirupsen/logrus" ) -var ( - testOpts = SenderOptions{ - HealthCheckInterval: 2 * time.Second, - HealthCheckTimeout: 100 * time.Millisecond, - } -) +func TestMain(m *testing.M) { + // override the health check interval to speed up the test + healthCheckInterval = 2 * time.Second + healthCheckTimeout = 100 * time.Millisecond + code := m.Run() + os.Exit(code) +} func TestNewHealthPeriod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) iterations := 0 @@ -29,7 +32,7 @@ func TestNewHealthPeriod(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): + case <-time.After(healthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -38,19 +41,19 @@ func TestNewHealthPeriod(t *testing.T) { func TestNewHealthFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) select { case <-hc.Timeout: - case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond): + case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): t.Fatalf("health check is not timed out") } } func TestNewHealthcheckStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) time.Sleep(100 * time.Millisecond) @@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) { func TestTimeoutReset(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) iterations := 0 @@ -86,7 +89,7 @@ func TestTimeoutReset(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): + case <-time.After(healthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -115,16 +118,19 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - opts := SenderOptions{ - HealthCheckInterval: 1 * time.Second, - HealthCheckTimeout: 500 * time.Millisecond, - AttemptThreshold: tc.threshold, - } + originalInterval := healthCheckInterval + originalTimeout := healthCheckTimeout + healthCheckInterval = 1 * time.Second + healthCheckTimeout = 500 * time.Millisecond + + //nolint:tenv + os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) + defer os.Unsetenv(defaultAttemptThresholdEnv) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts) + sender := NewSender(log.WithField("test_name", tc.name)) senderExit := make(chan struct{}) go func() { sender.StartHealthCheck(ctx) @@ -149,7 +155,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { } }() - testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval + testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval select { case <-sender.Timeout: @@ -169,7 +175,39 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } + healthCheckInterval = originalInterval + healthCheckTimeout = originalTimeout }) } } + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5ca0c0282..82ab678f4 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -16,10 +16,10 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" + nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/sharedsock/example/main.go b/sharedsock/example/main.go index da62b276e..9384d2b1c 100644 --- a/sharedsock/example/main.go +++ b/sharedsock/example/main.go @@ -5,16 +5,14 @@ import ( "os" "os/signal" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/sharedsock" + log "github.com/sirupsen/logrus" ) func main() { port := 51820 - rawSock, err := sharedsock.Listen(port, sharedsock.NewIncomingSTUNFilter(), iface.DefaultMTU) + rawSock, err := sharedsock.Listen(port, sharedsock.NewIncomingSTUNFilter()) if err != nil { panic(err) } diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index bc2d4d1be..1c22e7869 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -22,7 +22,7 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - nbnet "github.com/netbirdio/netbird/client/net" + nbnet "github.com/netbirdio/netbird/util/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -36,7 +36,6 @@ type SharedSocket struct { conn4 *socket.Conn conn6 *socket.Conn port int - mtu uint16 routerMux sync.RWMutex router routing.Router packetDemux chan rcvdPacket @@ -57,19 +56,12 @@ var writeSerializerOptions = gopacket.SerializeOptions{ FixLengths: true, } -// Maximum overhead for IP + UDP headers on raw socket -// IPv4: max 60 bytes (20 base + 40 options) + UDP 8 bytes = 68 bytes -// IPv6: 40 bytes + UDP 8 bytes = 48 bytes -// We use the maximum (68) for both IPv4 and IPv6 -const maxIPUDPOverhead = 68 - // Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines -func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error) { +func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { ctx, cancel := context.WithCancel(context.Background()) rawSock := &SharedSocket{ ctx: ctx, cancel: cancel, - mtu: mtu, port: port, packetDemux: make(chan rcvdPacket), } @@ -93,7 +85,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err) + return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) } var sockErr error @@ -102,7 +94,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error log.Errorf("Failed to create ipv6 raw socket: %v", err) } else { if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err) + return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) } } @@ -231,7 +223,7 @@ func (s *SharedSocket) Close() error { // read start a read loop for a specific receiver and sends the packet to the packetDemux channel func (s *SharedSocket) read(receiver receiver) { for { - buf := make([]byte, s.mtu+maxIPUDPOverhead) + buf := make([]byte, 1500) n, addr, err := receiver(s.ctx, buf, 0) select { case <-s.ctx.Done(): diff --git a/sharedsock/sock_linux_test.go b/sharedsock/sock_linux_test.go index a22af461a..f5c85119c 100644 --- a/sharedsock/sock_linux_test.go +++ b/sharedsock/sock_linux_test.go @@ -21,7 +21,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) { // create raw socket on a port testingPort := 51821 - rawSock, err := Listen(testingPort, NewIncomingSTUNFilter(), 1280) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) err = rawSock.SetReadDeadline(time.Now().Add(3 * time.Second)) require.NoError(t, err, "unable to set deadline, error: %s", err) @@ -76,7 +76,7 @@ func TestShouldReadSTUNOnReadFrom(t *testing.T) { func TestShouldNotReadNonSTUNPackets(t *testing.T) { testingPort := 39439 - rawSock, err := Listen(testingPort, NewIncomingSTUNFilter(), 1280) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) defer rawSock.Close() @@ -110,7 +110,7 @@ func TestWriteTo(t *testing.T) { defer udpListener.Close() testingPort := 39440 - rawSock, err := Listen(testingPort, NewIncomingSTUNFilter(), 1280) + rawSock, err := Listen(testingPort, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) defer rawSock.Close() @@ -144,7 +144,7 @@ func TestWriteTo(t *testing.T) { } func TestSharedSocket_Close(t *testing.T) { - rawSock, err := Listen(39440, NewIncomingSTUNFilter(), 1280) + rawSock, err := Listen(39440, NewIncomingSTUNFilter()) require.NoError(t, err, "received an error while creating STUN listener, error: %s", err) errGrp := errgroup.Group{} diff --git a/sharedsock/sock_nolinux.go b/sharedsock/sock_nolinux.go index a92f22edf..a36ef67c6 100644 --- a/sharedsock/sock_nolinux.go +++ b/sharedsock/sock_nolinux.go @@ -9,6 +9,6 @@ import ( ) // Listen is not supported on other platforms then Linux -func Listen(port int, filter BPFFilter, mtu uint16) (net.PacketConn, error) { +func Listen(port int, filter BPFFilter) (net.PacketConn, error) { return nil, fmt.Errorf("not supported OS %s. SharedSocket is only supported on Linux", runtime.GOOS) } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 1d76fa4e4..2e89b491a 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,7 +8,6 @@ import ( "fmt" "net" "net/http" - // nolint:gosec _ "net/http/pprof" "strings" diff --git a/signal/peer/peer.go b/signal/peer/peer.go index c9dd60fc0..f21c95a41 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -5,16 +5,10 @@ import ( "sync" "time" - "errors" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" -) - -var ( - ErrPeerAlreadyRegistered = errors.New("peer already registered") + "github.com/netbirdio/netbird/shared/signal/proto" ) // Peer representation of a connected Peer @@ -29,18 +23,15 @@ type Peer struct { // registration time RegisteredAt time.Time - - Cancel context.CancelFunc } // NewPeer creates a new instance of a connected Peer -func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer { +func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer { return &Peer{ Id: id, Stream: stream, StreamID: time.Now().UnixNano(), RegisteredAt: time.Now(), - Cancel: cancel, } } @@ -78,24 +69,20 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool { } // Register registers peer in the registry -func (registry *Registry) Register(peer *Peer) error { +func (registry *Registry) Register(peer *Peer) { start := time.Now() + registry.regMutex.Lock() + defer registry.regMutex.Unlock() + // can be that peer already exists, but it is fine (e.g. reconnect) p, loaded := registry.Peers.LoadOrStore(peer.Id, peer) if loaded { pp := p.(*Peer) - if peer.StreamID > pp.StreamID { - log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.", - peer.Id, peer.StreamID, pp.StreamID) - if swapped := registry.Peers.CompareAndSwap(peer.Id, pp, peer); !swapped { - return registry.Register(peer) - } - pp.Cancel() - log.Debugf("peer re-registered [%s]", peer.Id) - return nil - } - return ErrPeerAlreadyRegistered + log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.", + peer.Id, peer.StreamID, pp.StreamID) + registry.Peers.Store(peer.Id, peer) + return } log.Debugf("peer registered [%s]", peer.Id) @@ -105,13 +92,22 @@ func (registry *Registry) Register(peer *Peer) error { registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) registry.metrics.Registrations.Add(context.Background(), 1) - - return nil } // Deregister Peer from the Registry (usually once it disconnects) func (registry *Registry) Deregister(peer *Peer) { - if deleted := registry.Peers.CompareAndDelete(peer.Id, peer); deleted { + registry.regMutex.Lock() + defer registry.regMutex.Unlock() + + p, loaded := registry.Peers.LoadAndDelete(peer.Id) + if loaded { + pp := p.(*Peer) + if peer.StreamID < pp.StreamID { + registry.Peers.Store(peer.Id, p) + log.Debugf("attempted to remove newer registered stream of a peer [%s] [newer streamID %d, previous StreamID %d]. Ignoring.", + peer.Id, pp.StreamID, peer.StreamID) + return + } registry.metrics.ActivePeers.Add(context.Background(), -1) log.Debugf("peer deregistered [%s]", peer.Id) registry.metrics.Deregistrations.Add(context.Background(), 1) diff --git a/signal/peer/peer_test.go b/signal/peer/peer_test.go index 6b7976eb4..fb85fedda 100644 --- a/signal/peer/peer_test.go +++ b/signal/peer/peer_test.go @@ -1,18 +1,13 @@ package peer import ( - "context" - "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" - "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" ) @@ -24,16 +19,12 @@ func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T) peerID := "peer" - _, cancel1 := context.WithCancel(context.Background()) - olderPeer := NewPeer(peerID, nil, cancel1) - err = r.Register(olderPeer) - require.NoError(t, err) + olderPeer := NewPeer(peerID, nil) + r.Register(olderPeer) time.Sleep(time.Nanosecond) - _, cancel2 := context.WithCancel(context.Background()) - newerPeer := NewPeer(peerID, nil, cancel2) - err = r.Register(newerPeer) - require.NoError(t, err) + newerPeer := NewPeer(peerID, nil) + r.Register(newerPeer) registered, _ := r.Get(olderPeer.Id) assert.NotNil(t, registered, "peer can't be nil") @@ -68,14 +59,10 @@ func TestRegistry_Register(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - _, cancel1 := context.WithCancel(context.Background()) - peer1 := NewPeer("test_peer_1", nil, cancel1) - _, cancel2 := context.WithCancel(context.Background()) - peer2 := NewPeer("test_peer_2", nil, cancel2) - err = r.Register(peer1) - require.NoError(t, err) - err = r.Register(peer2) - require.NoError(t, err) + peer1 := NewPeer("test_peer_1", nil) + peer2 := NewPeer("test_peer_2", nil) + r.Register(peer1) + r.Register(peer2) if _, ok := r.Get("test_peer_1"); !ok { t.Errorf("expected test_peer_1 not found in the registry") @@ -91,14 +78,10 @@ func TestRegistry_Deregister(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - _, cancel1 := context.WithCancel(context.Background()) - peer1 := NewPeer("test_peer_1", nil, cancel1) - _, cancel2 := context.WithCancel(context.Background()) - peer2 := NewPeer("test_peer_2", nil, cancel2) - err = r.Register(peer1) - require.NoError(t, err) - err = r.Register(peer2) - require.NoError(t, err) + peer1 := NewPeer("test_peer_1", nil) + peer2 := NewPeer("test_peer_2", nil) + r.Register(peer1) + r.Register(peer2) r.Deregister(peer1) @@ -111,213 +94,3 @@ func TestRegistry_Deregister(t *testing.T) { } } - -func TestRegistry_MultipleRegister_Concurrency(t *testing.T) { - - metrics, err := metrics.NewAppMetrics(otel.Meter("")) - require.NoError(t, err) - registry := NewRegistry(metrics) - - numGoroutines := 1000 - - ids := make(chan int64, numGoroutines) - - var wg sync.WaitGroup - wg.Add(numGoroutines) - peerID := "peer-concurrent" - for i := range numGoroutines { - go func(routineIndex int) { - defer wg.Done() - - _, cancel := context.WithCancel(context.Background()) - peer := NewPeer(peerID, nil, cancel) - _ = registry.Register(peer) - ids <- peer.StreamID - }(i) - } - - wg.Wait() - close(ids) - maxId := int64(0) - for id := range ids { - maxId = max(maxId, id) - } - - peer, ok := registry.Get(peerID) - require.True(t, ok, "expected peer to be registered") - require.Equal(t, maxId, peer.StreamID, "expected the highest StreamID to be registered") -} - -func Benchmark_MultipleRegister_Concurrency(b *testing.B) { - - metrics, err := metrics.NewAppMetrics(otel.Meter("")) - require.NoError(b, err) - - numGoroutines := 1000 - - var wg sync.WaitGroup - peerID := "peer-concurrent" - _, cancel := context.WithCancel(context.Background()) - b.Run("multiple-register", func(b *testing.B) { - registry := NewRegistry(metrics) - b.ResetTimer() - for j := 0; j < b.N; j++ { - wg.Add(numGoroutines) - for i := range numGoroutines { - go func(routineIndex int) { - defer wg.Done() - - peer := NewPeer(peerID, nil, cancel) - _ = registry.Register(peer) - }(i) - } - wg.Wait() - } - }) -} - -func TestRegistry_MultipleDeregister_Concurrency(t *testing.T) { - - metrics, err := metrics.NewAppMetrics(otel.Meter("")) - require.NoError(t, err) - registry := NewRegistry(metrics) - - numGoroutines := 1000 - - ids := make(chan int64, numGoroutines) - - var wg sync.WaitGroup - wg.Add(numGoroutines) - peerID := "peer-concurrent" - for i := range numGoroutines { - go func(routineIndex int) { - defer wg.Done() - - _, cancel := context.WithCancel(context.Background()) - peer := NewPeer(peerID, nil, cancel) - _ = registry.Register(peer) - ids <- peer.StreamID - registry.Deregister(peer) - }(i) - } - - wg.Wait() - close(ids) - maxId := int64(0) - for id := range ids { - maxId = max(maxId, id) - } - - _, ok := registry.Get(peerID) - require.False(t, ok, "expected peer to be deregistered") -} - -func Benchmark_MultipleDeregister_Concurrency(b *testing.B) { - - metrics, err := metrics.NewAppMetrics(otel.Meter("")) - require.NoError(b, err) - - numGoroutines := 1000 - - var wg sync.WaitGroup - peerID := "peer-concurrent" - _, cancel := context.WithCancel(context.Background()) - b.Run("register-deregister", func(b *testing.B) { - registry := NewRegistry(metrics) - b.ResetTimer() - for j := 0; j < b.N; j++ { - wg.Add(numGoroutines) - for i := range numGoroutines { - go func(routineIndex int) { - defer wg.Done() - - peer := NewPeer(peerID, nil, cancel) - _ = registry.Register(peer) - time.Sleep(time.Nanosecond) - registry.Deregister(peer) - }(i) - } - wg.Wait() - } - }) -} - -type mockConnectStreamServer struct { - grpc.ServerStream - ctx context.Context -} - -func (m *mockConnectStreamServer) Context() context.Context { - return m.ctx -} - -func (m *mockConnectStreamServer) SendHeader(md metadata.MD) error { - return nil -} - -func (m *mockConnectStreamServer) Send(msg *proto.EncryptedMessage) error { - return nil -} - -func (m *mockConnectStreamServer) Recv() (*proto.EncryptedMessage, error) { - <-m.ctx.Done() - return nil, m.ctx.Err() -} - -func TestReconnectHandling(t *testing.T) { - metrics, err := metrics.NewAppMetrics(otel.Meter("")) - require.NoError(t, err) - registry := NewRegistry(metrics) - peerID := "test-peer-reconnect" - - ctx1, cancel1 := context.WithCancel(context.Background()) - defer cancel1() - stream1 := &mockConnectStreamServer{ctx: ctx1} - peer1 := NewPeer(peerID, stream1, cancel1) - - err = registry.Register(peer1) - require.NoError(t, err, "first registration should succeed") - - p, found := registry.Get(peerID) - require.True(t, found, "peer should be found in the registry") - require.Equal(t, peer1.StreamID, p.StreamID, "StreamID of registered peer should match") - - time.Sleep(time.Nanosecond) - ctx2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() - stream2 := &mockConnectStreamServer{ctx: ctx2} - peer2 := NewPeer(peerID, stream2, cancel2) - - err = registry.Register(peer2) - require.NoError(t, err, "reconnect registration should succeed") - - select { - case <-ctx1.Done(): - require.ErrorIs(t, ctx1.Err(), context.Canceled, "context of old stream should be canceled after successful reconnection") - case <-time.After(100 * time.Millisecond): - t.Fatal("context of old stream was not canceled after reconnection") - } - - p, found = registry.Get(peerID) - require.True(t, found) - require.Equal(t, peer2.StreamID, p.StreamID, "registered peer should have the new StreamID after reconnection") - - ctx3, cancel3 := context.WithCancel(context.Background()) - defer cancel3() - stream3 := &mockConnectStreamServer{ctx: ctx3} - stalePeer := NewPeer(peerID, stream3, cancel3) - stalePeer.StreamID = peer1.StreamID - - err = registry.Register(stalePeer) - require.ErrorIs(t, err, ErrPeerAlreadyRegistered, "reconnecting with an old StreamID should return an error") - - p, found = registry.Get(peerID) - require.True(t, found) - require.Equal(t, peer2.StreamID, p.StreamID, "active peer should still be the one with the latest StreamID") - - select { - case <-ctx2.Done(): - t.Fatal("context of the new stream should not be canceled after trying to register with an old StreamID") - default: - } -} diff --git a/signal/server/signal.go b/signal/server/signal.go index 47f01edae..8ae14822b 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -2,9 +2,7 @@ package server import ( "context" - "errors" "fmt" - "os" "time" log "github.com/sirupsen/logrus" @@ -17,9 +15,9 @@ import ( "github.com/netbirdio/signal-dispatcher/dispatcher" - "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" + "github.com/netbirdio/netbird/shared/signal/proto" ) const ( @@ -29,8 +27,6 @@ const ( labelTypeNotRegistered = "not_registered" labelTypeStream = "stream" labelTypeMessage = "message" - labelTypeTimeout = "timeout" - labelTypeDisconnected = "disconnected" labelError = "error" labelErrorMissingId = "missing_id" @@ -41,12 +37,6 @@ const ( labelRegistrationStatus = "status" labelRegistrationFound = "found" labelRegistrationNotFound = "not_found" - - sendTimeout = 10 * time.Second -) - -var ( - ErrPeerRegisteredAgain = errors.New("peer registered again") ) // Server an instance of a Signal server @@ -55,10 +45,6 @@ type Server struct { proto.UnimplementedSignalExchangeServer dispatcher *dispatcher.Dispatcher metrics *metrics.AppMetrics - - successHeader metadata.MD - - sendTimeout time.Duration } // NewServer creates a new Signal server @@ -73,19 +59,10 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating dispatcher: %v", err) } - sTimeout := sendTimeout - to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT") - if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 { - log.Trace("using custom send timeout ", parsed) - sTimeout = parsed - } - s := &Server{ - dispatcher: d, - registry: peer.NewRegistry(appMetrics), - metrics: appMetrics, - successHeader: metadata.Pairs(proto.HeaderRegistered, "1"), - sendTimeout: sTimeout, + dispatcher: d, + registry: peer.NewRegistry(appMetrics), + metrics: appMetrics, } return s, nil @@ -105,8 +82,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. // ConnectStream connects to the exchange stream func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error { - ctx, cancel := context.WithCancel(context.Background()) - p, err := s.RegisterPeer(stream, cancel) + p, err := s.RegisterPeer(stream) if err != nil { return err } @@ -114,7 +90,8 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) defer s.DeregisterPeer(p) // needed to confirm that the peer has been registered so that the client can proceed - err = stream.SendHeader(s.successHeader) + header := metadata.Pairs(proto.HeaderRegistered, "1") + err = stream.SendHeader(header) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader))) return err @@ -122,27 +99,27 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) - select { - case <-stream.Context().Done(): - log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) - return nil - case <-ctx.Done(): - return ErrPeerRegisteredAgain - } + <-stream.Context().Done() + log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) + return nil } -func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) { +func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { log.Debugf("registering new peer") - id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId) - if id == nil { + meta, hasMeta := metadata.FromIncomingContext(stream.Context()) + if !hasMeta { + s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta))) + return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") + } + + id, found := meta[proto.HeaderId] + if !found { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId) } - p := peer.NewPeer(id[0], stream, cancel) - if err := s.registry.Register(p); err != nil { - return nil, err - } + p := peer.NewPeer(id[0], stream) + s.registry.Register(p) err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration))) @@ -154,8 +131,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, c func (s *Server) DeregisterPeer(p *peer.Peer) { log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) - s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) s.registry.Deregister(p) + s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) } func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { @@ -168,7 +145,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM if !found { s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) - log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) + log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) // todo respond to the sender? return } @@ -176,34 +153,16 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) start := time.Now() - sendResultChan := make(chan error, 1) - go func() { - select { - case sendResultChan <- dstPeer.Stream.Send(msg): - return - case <-dstPeer.Stream.Context().Done(): - return - } - }() - - select { - case err := <-sendResultChan: - if err != nil { - log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err) - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - return - } - s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) - s.metrics.MessagesForwarded.Add(ctx, 1) - s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) - - case <-dstPeer.Stream.Context().Done(): - log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey) - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected))) - - case <-time.After(s.sendTimeout): - dstPeer.Cancel() // cancel the peer context to trigger deregistration - log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey) - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout))) + // forward the message to the target peer + if err := dstPeer.Stream.Send(msg); err != nil { + log.Tracef("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + // todo respond to the sender? + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + return } + + // in milliseconds + s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(ctx, 1) + s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) } diff --git a/client/grpc/dialer.go b/util/grpc/dialer.go similarity index 91% rename from client/grpc/dialer.go rename to util/grpc/dialer.go index 69e3f088c..f6d6d2f04 100644 --- a/client/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -20,9 +20,8 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - nbnet "github.com/netbirdio/netbird/client/net" - "github.com/netbirdio/netbird/util/embeddedroots" + nbnet "github.com/netbirdio/netbird/util/net" ) func WithCustomDialer() grpc.DialOption { @@ -58,7 +57,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -72,7 +71,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. })) } - connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() conn, err := grpc.DialContext( diff --git a/util/net/conn.go b/util/net/conn.go new file mode 100644 index 000000000..26693f841 --- /dev/null +++ b/util/net/conn.go @@ -0,0 +1,31 @@ +//go:build !ios + +package net + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} diff --git a/util/net/dial.go b/util/net/dial.go new file mode 100644 index 000000000..595311492 --- /dev/null +++ b/util/net/dial.go @@ -0,0 +1,58 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/client/net/dial_ios.go b/util/net/dial_ios.go similarity index 100% rename from client/net/dial_ios.go rename to util/net/dial_ios.go diff --git a/client/net/dialer.go b/util/net/dialer.go similarity index 99% rename from client/net/dialer.go rename to util/net/dialer.go index 29bec05a7..0786c667e 100644 --- a/client/net/dialer.go +++ b/util/net/dialer.go @@ -16,5 +16,6 @@ func NewDialer() *Dialer { Dialer: &net.Dialer{}, } dialer.init() + return dialer } diff --git a/util/net/dialer_dial.go b/util/net/dialer_dial.go new file mode 100644 index 000000000..1659b6220 --- /dev/null +++ b/util/net/dialer_dial.go @@ -0,0 +1,107 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" +) + +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc +) + +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) +} + +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) +} + +// RemoveDialerHooks removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil +} + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + log.Debugf("Dialing %s %s", network, address) + + if CustomRoutingDisabled() { + return d.Dialer.DialContext(ctx, network, address) + } + + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialerDialHooks != nil { + if err := callDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() +} diff --git a/client/net/dialer_init_android.go b/util/net/dialer_init_android.go similarity index 100% rename from client/net/dialer_init_android.go rename to util/net/dialer_init_android.go diff --git a/client/net/dialer_init_linux.go b/util/net/dialer_init_linux.go similarity index 100% rename from client/net/dialer_init_linux.go rename to util/net/dialer_init_linux.go diff --git a/util/net/dialer_init_nonlinux.go b/util/net/dialer_init_nonlinux.go new file mode 100644 index 000000000..8c57ebbaa --- /dev/null +++ b/util/net/dialer_init_nonlinux.go @@ -0,0 +1,7 @@ +//go:build !linux + +package net + +func (d *Dialer) init() { + // implemented on Linux and Android only +} diff --git a/client/net/env.go b/util/net/env.go similarity index 94% rename from client/net/env.go rename to util/net/env.go index 8f326ca88..32425665d 100644 --- a/client/net/env.go +++ b/util/net/env.go @@ -11,7 +11,6 @@ import ( const ( envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" - envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" ) // CustomRoutingDisabled returns true if custom routing is disabled. diff --git a/util/net/env_generic.go b/util/net/env_generic.go new file mode 100644 index 000000000..6d142a838 --- /dev/null +++ b/util/net/env_generic.go @@ -0,0 +1,12 @@ +//go:build !linux || android + +package net + +func Init() { + // nothing to do on non-linux +} + +func AdvancedRouting() bool { + // non-linux currently doesn't support advanced routing + return false +} diff --git a/client/net/env_linux.go b/util/net/env_linux.go similarity index 86% rename from client/net/env_linux.go rename to util/net/env_linux.go index 82d9a74a8..3159f6462 100644 --- a/client/net/env_linux.go +++ b/util/net/env_linux.go @@ -17,7 +17,8 @@ import ( const ( // these have the same effect, skip socket env supported for backward compatibility - envSkipSocketMark = "NB_SKIP_SOCKET_MARK" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" + envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" ) var advancedRoutingSupported bool @@ -26,7 +27,6 @@ func Init() { advancedRoutingSupported = checkAdvancedRoutingSupport() } -// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes func AdvancedRouting() bool { return advancedRoutingSupported } @@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool { } func CheckFwmarkSupport() bool { - // temporarily enable advanced routing to check if fwmarks are supported + // temporarily enable advanced routing to check fwmarks are supported old := advancedRoutingSupported advancedRoutingSupported = true defer func() { @@ -129,13 +129,3 @@ func CheckRuleOperationsSupport() bool { } return true } - -// SetVPNInterfaceName is a no-op on Linux -func SetVPNInterfaceName(name string) { - // No-op on Linux - not needed for fwmark-based routing -} - -// GetVPNInterfaceName returns empty string on Linux -func GetVPNInterfaceName() string { - return "" -} diff --git a/util/net/listen.go b/util/net/listen.go new file mode 100644 index 000000000..3ae8a9435 --- /dev/null +++ b/util/net/listen.go @@ -0,0 +1,37 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/client/net/listen_ios.go b/util/net/listen_ios.go similarity index 100% rename from client/net/listen_ios.go rename to util/net/listen_ios.go diff --git a/client/net/listener.go b/util/net/listener.go similarity index 81% rename from client/net/listener.go rename to util/net/listener.go index 4c2f53c05..f4d769f58 100644 --- a/client/net/listener.go +++ b/util/net/listener.go @@ -7,12 +7,14 @@ import ( // ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before // responding via the socket and after closing. This can be used to bypass the VPN for listeners. type ListenerConfig struct { - net.ListenConfig + *net.ListenConfig } // NewListener creates a new ListenerConfig instance. func NewListener() *ListenerConfig { - listener := &ListenerConfig{} + listener := &ListenerConfig{ + ListenConfig: &net.ListenConfig{}, + } listener.init() return listener diff --git a/client/net/listener_init_android.go b/util/net/listener_init_android.go similarity index 100% rename from client/net/listener_init_android.go rename to util/net/listener_init_android.go diff --git a/client/net/listener_init_linux.go b/util/net/listener_init_linux.go similarity index 100% rename from client/net/listener_init_linux.go rename to util/net/listener_init_linux.go diff --git a/util/net/listener_init_nonlinux.go b/util/net/listener_init_nonlinux.go new file mode 100644 index 000000000..80f6f7f1a --- /dev/null +++ b/util/net/listener_init_nonlinux.go @@ -0,0 +1,7 @@ +//go:build !linux + +package net + +func (l *ListenerConfig) init() { + // implemented on Linux and Android only +} diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go new file mode 100644 index 000000000..4060ab49a --- /dev/null +++ b/util/net/listener_listen.go @@ -0,0 +1,205 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. +type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc + listenerAddressRemoveHooksMutex sync.RWMutex + listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) +} + +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. +func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { + listenerAddressRemoveHooksMutex.Lock() + defer listenerAddressRemoveHooksMutex.Unlock() + listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) +} + +// RemoveListenerHooks removes all listener hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil + + listenerAddressRemoveHooksMutex.Lock() + defer listenerAddressRemoveHooksMutex.Unlock() + listenerAddressRemoveHooks = nil +} + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.UDPConn) +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *PacketConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) + + listenerAddressRemoveHooksMutex.RLock() + defer listenerAddressRemoveHooksMutex.RUnlock() + + for _, hook := range listenerAddressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + + +// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality +func WrapPacketConn(conn net.PacketConn) *PacketConn { + return &PacketConn{ + PacketConn: conn, + ID: GenerateConnID(), + seenAddrs: &sync.Map{}, + } +} + +func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + return + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + return + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(id, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } +} + +func closeConn(id ConnectionID, conn net.PacketConn) error { + err := conn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(id, conn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + return err +} diff --git a/client/net/listener_listen_ios.go b/util/net/listener_listen_ios.go similarity index 100% rename from client/net/listener_listen_ios.go rename to util/net/listener_listen_ios.go diff --git a/client/net/net.go b/util/net/net.go similarity index 81% rename from client/net/net.go rename to util/net/net.go index a97de9d59..fdcf4ee6a 100644 --- a/client/net/net.go +++ b/util/net/net.go @@ -5,6 +5,8 @@ import ( "math/big" "net" "net/netip" + + "github.com/google/uuid" ) const ( @@ -42,6 +44,18 @@ func IsDataPlaneMark(fwmark uint32) bool { return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper } +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +type AddHookFunc func(connID ConnectionID, IP net.IP) error +type RemoveHookFunc func(connID ConnectionID) error + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} + func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) { var endIP net.IP addr := network.Addr().AsSlice() diff --git a/client/net/net_linux.go b/util/net/net_linux.go similarity index 100% rename from client/net/net_linux.go rename to util/net/net_linux.go diff --git a/client/net/net_test.go b/util/net/net_test.go similarity index 100% rename from client/net/net_test.go rename to util/net/net_test.go diff --git a/client/net/protectsocket_android.go b/util/net/protectsocket_android.go similarity index 89% rename from client/net/protectsocket_android.go rename to util/net/protectsocket_android.go index 00071461d..febed8a1e 100644 --- a/client/net/protectsocket_android.go +++ b/util/net/protectsocket_android.go @@ -4,8 +4,6 @@ import ( "fmt" "sync" "syscall" - - "github.com/netbirdio/netbird/client/iface/netstack" ) var ( @@ -21,9 +19,6 @@ func SetAndroidProtectSocketFn(fn func(fd int32) bool) { // ControlProtectSocket is a Control function that sets the fwmark on the socket func ControlProtectSocket(_, _ string, c syscall.RawConn) error { - if netstack.IsEnabled() { - return nil - } var aErr error err := c.Control(func(fd uintptr) { androidProtectSocketLock.Lock()