diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 0013833c4..ba36c013b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -217,7 +217,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "" + raceFlag: "-race" runs-on: ubuntu-22.04 steps: - name: Install Go @@ -382,6 +382,32 @@ jobs: store: [ 'sqlite', 'postgres' ] runs-on: ubuntu-22.04 steps: + - name: Create Docker network + run: docker network create promnet + + - name: Start Prometheus Pushgateway + run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway + + - name: Start Prometheus (for Pushgateway forwarding) + run: | + echo ' + global: + scrape_interval: 15s + scrape_configs: + - job_name: "pushgateway" + static_configs: + - targets: ["pushgateway:9091"] + remote_write: + - url: ${{ secrets.GRAFANA_URL }} + basic_auth: + username: ${{ secrets.GRAFANA_USER }} + password: ${{ secrets.GRAFANA_API_KEY }} + ' > prometheus.yml + + docker run -d --name prometheus --network promnet \ + -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ + -p 9090:9090 \ + prom/prometheus - name: Install Go uses: actions/setup-go@v5 with: @@ -428,9 +454,10 @@ jobs: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \ CI=true \ + GIT_BRANCH=${{ github.ref_name }} \ go test -tags devcert -run=^$ -bench=. \ - -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./management/... ./shared/management/... + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ + -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http) api_benchmark: name: "Management / Benchmark (API)" @@ -521,7 +548,7 @@ jobs: -run=^$ \ -bench=. \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ - -timeout 20m ./management/... 
./shared/management/... + -timeout 20m ./management/server/http/... api_integration_test: name: "Management / Integration" @@ -571,4 +598,4 @@ jobs: CI=true \ go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./management/... ./shared/management/... \ No newline at end of file + -timeout 20m ./management/server/http/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index d9ff0a84b..2083c0721 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -63,7 +63,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV + - run: echo "files=$(go list ./... 
| ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV - name: test run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7be52259b..e9741f541 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.22" + SIGN_PIPE_VER: "v0.0.23" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" diff --git a/README.md b/README.md index ea7655869..2c5ee2ab6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +


@@ -52,7 +53,7 @@ ### Open Source Network Security in a Single Platform -centralized-network-management 1 +https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) diff --git a/client/Dockerfile b/client/Dockerfile index e19a09909..b2f627409 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -18,7 +18,7 @@ ENV \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + NB_ENTRYPOINT_LOGIN_TIMEOUT="5" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/android/client.go b/client/android/client.go index c05246569..d2d0c37f6 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,6 +4,7 @@ package android import ( "context" + "os" "slices" "sync" @@ -18,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net" ) // ConnectionListener export internal Listener for mobile @@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi } // Run start the internal client. 
It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } + } +} diff --git a/client/android/env_list.go b/client/android/env_list.go new file mode 100644 index 000000000..04122300a --- /dev/null +++ b/client/android/env_list.go @@ -0,0 +1,32 @@ +package android + +import "github.com/netbirdio/netbird/client/internal/peer" + +var ( + // EnvKeyNBForceRelay Exported for Android java client + EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay +) + +// EnvList wraps a Go map for export to Java +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair 
+func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} diff --git a/client/android/login.go b/client/android/login.go index d8ac645e2..0df78dbc3 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -33,6 +33,7 @@ type ErrListener interface { // the backend want to show an url for the user type URLOpener interface { Open(string) + OnLoginSuccess() } // Auth can register or login new client @@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error { err = a.withBackOff(a.ctx, func() error { err := internal.Login(a.ctx, a.config, "", jwtToken) + + if err == nil { + go urlOpener.OnLoginSuccess() + } + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return nil } diff --git a/client/cmd/down.go b/client/cmd/down.go index 3ce51c678..17c152d22 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -27,7 +27,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/cmd/login.go b/client/cmd/login.go index 92de6abdb..3ac211805 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, } // update host's static platform and system information - system.UpdateStaticInfo() + system.UpdateStaticInfoAsync() configFilePath, err := activeProf.FilePath() if err != nil { diff --git a/client/cmd/root.go b/client/cmd/root.go index 290cae258..9f2eb109c 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -228,7 +228,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix 
string) string { // DialClientGRPCServer returns client connection to the daemon server. func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*3) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() return grpc.DialContext( diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 50fb35d5e..0545ce6b7 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error { log.Info("starting NetBird service") //nolint // Collect static system and platform information - system.UpdateStaticInfo() + system.UpdateStaticInfoAsync() // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. p.serv = grpc.NewServer() diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index e45443751..99ccb1539 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -9,29 +9,26 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "google.golang.org/grpc" + "github.com/netbirdio/management-integrations/integrations" + clientProto "github.com/netbirdio/netbird/client/proto" + client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" + mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" 
"github.com/netbirdio/netbird/management/server/types" - - "github.com/netbirdio/netbird/util" - - "google.golang.org/grpc" - - "github.com/netbirdio/management-integrations/integrations" - - clientProto "github.com/netbirdio/netbird/client/proto" - client "github.com/netbirdio/netbird/client/server" - mgmt "github.com/netbirdio/netbird/management/server" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto" sig "github.com/netbirdio/netbird/signal/server" + "github.com/netbirdio/netbird/util" ) func startTestingServices(t *testing.T) string { @@ -90,15 +87,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { return nil, nil } - iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - require.NoError(t, err) ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) - settingsMockManager := settings.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl) + peersmanager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) + + iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() settingsMockManager.EXPECT(). 
diff --git a/client/cmd/up.go b/client/cmd/up.go index 61b442cea..1a553711d 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager client := proto.NewDaemonServiceClient(conn) - status, err := client.Status(ctx, &proto.StatusRequest{}) + status, err := client.Status(ctx, &proto.StatusRequest{ + WaitForReady: func() *bool { b := true; return &b }(), + }) if err != nil { return fmt.Errorf("unable to get daemon status: %v", err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index de83f9d96..0bfc7a37c 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error { // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available - run := make(chan struct{}, 1) + run := make(chan struct{}) clientErr := make(chan error, 1) go func() { if err := client.Run(run); err != nil { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 7b90000a8..ed8a7403b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -12,7 +12,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index d8e8857d4..343f5e05e 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet 
"github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // constants needed to manage and create iptable rules diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index e9eeff863..3490c5dad 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,7 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) func isIptablesSupported() bool { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 52979d257..9ff5b8c92 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aa4098821..0c091da96 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -22,7 +22,7 @@ import ( nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/util/grpc/dialer.go b/client/grpc/dialer.go similarity index 91% rename from util/grpc/dialer.go rename to client/grpc/dialer.go index f6d6d2f04..69e3f088c 100644 --- a/util/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -20,8 +20,9 @@ import ( "google.golang.org/grpc/credentials/insecure" 
"google.golang.org/grpc/keepalive" + nbnet "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" ) func WithCustomDialer() grpc.DialOption { @@ -57,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -71,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { })) } - connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() conn, err := grpc.DialContext( diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go index 89bddf12c..32b07c330 100644 --- a/client/iface/bind/control.go +++ b/client/iface/bind/control.go @@ -3,7 +3,7 @@ package bind import ( wireguard "golang.zx2c4.com/wireguard/conn" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index f23be406e..577c7c0c4 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -8,15 +8,16 @@ import ( "runtime" "sync" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" 
+ nbnet "github.com/netbirdio/netbird/client/net" ) type RecvMessage struct { @@ -44,7 +45,7 @@ type ICEBind struct { RecvChan chan RecvMessage transportNet transport.Net - filterFn FilterFn + filterFn udpmux.FilterFn endpoints map[netip.Addr]net.Conn endpointsMu sync.Mutex // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a @@ -54,13 +55,13 @@ type ICEBind struct { closed bool muUDPMux sync.Mutex - udpMux *UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault address wgaddr.Address mtu uint16 activityRecorder *ActivityRecorder } -func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, @@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder { } // GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { +func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() if s.udpMux == nil { @@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = NewUniversalUDPMuxDefault( - UniversalUDPMuxParams{ + s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ UDPConn: nbnet.WrapPacketConn(conn), Net: s.transportNet, FilterFn: s.filterFn, diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go deleted file mode 100644 index db0249d11..000000000 --- a/client/iface/bind/udp_mux_ios.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build ios - -package bind - -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { - // iOS doesn't support nbnet hooks, so this is a no-op 
-} diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 171458e38..f744e0127 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -17,8 +17,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/monotime" - nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) { if err != nil { return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) } + + // If sec is 0 (Unix epoch), return zero time instead + // This indicates no handshake has occurred + if sec == 0 { + return time.Time{}, nil + } + return time.Unix(sec, 0), nil } @@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) { } func getFwmark() int { - if nbnet.AdvancedRouting() { + if nbnet.AdvancedRouting() && runtime.GOOS == "linux" { return nbnet.ControlPlaneMark } return 0 diff --git a/client/iface/device.go b/client/iface/device.go index ca6dda2c2..921f0ea98 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -7,14 +7,14 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create() (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index fe3b9f82e..a731684cc 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" 
"github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type WGTunDevice struct { name string device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string } return t.configurer, nil } -func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index cce9d42df..390efe088 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -26,7 +27,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 168985b5e..96e4c8bcf 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + 
"github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -28,7 +29,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 00a72bcc6..cdac43a53 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -12,11 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/sharedsock" - nbnet "github.com/netbirdio/netbird/util/net" ) type TunKernelDevice struct { @@ -31,9 +31,9 @@ type TunKernelDevice struct { link *wgLink udpMuxConn net.PacketConn - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault - filterFn bind.FilterFn + filterFn udpmux.FilterFn } func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { @@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { return configurer, nil } -func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -101,19 +101,14 @@ func (t 
*TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return nil, err } - var udpConn net.PacketConn = rawSock - if !nbnet.AdvancedRouting() { - udpConn = nbnet.WrapPacketConn(rawSock) - } - - bindParams := bind.UniversalUDPMuxParams{ - UDPConn: udpConn, + bindParams := udpmux.UniversalUDPMuxParams{ + UDPConn: nbnet.WrapPacketConn(rawSock), Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, MTU: t.mtu, } - mux := bind.NewUniversalUDPMuxDefault(bindParams) + mux := udpmux.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) t.udpMuxConn = rawSock t.udpMux = mux diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index f41331ff7..a6ef47027 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -10,8 +10,9 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type TunNetstackDevice struct { @@ -26,7 +27,7 @@ type TunNetstackDevice struct { device *device.Device filteredDevice *FilteredDevice nsTun *nbnetstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer net *netstack.Net @@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 8d30112ae..4cdd70a32 100644 --- 
a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -25,7 +26,7 @@ type USPDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index de258868f..f1023bc0a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type TunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 39b5c28ae..4649b8b97 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -5,14 +5,14 @@ 
import ( "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/iface.go b/client/iface/iface.go index 9a42223a1..609572561 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -16,9 +16,9 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -61,7 +61,7 @@ type WGIFaceOpts struct { MTU uint16 MobileArgs *device.MobileIFaceArguments TransportNet transport.Net - FilterFn bind.FilterFn + FilterFn udpmux.FilterFn DisableDNS bool } @@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface { // Up configures a Wireguard interface // The interface must exist before calling this method (e.g. 
call interface.Create() before) -func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/iface/bind/udp_muxed_conn.go b/client/iface/udpmux/conn.go similarity index 95% rename from client/iface/bind/udp_muxed_conn.go rename to client/iface/udpmux/conn.go index 7cacf1c31..3aa40caeb 100644 --- a/client/iface/bind/udp_muxed_conn.go +++ b/client/iface/udpmux/conn.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements @@ -16,11 +16,12 @@ import ( ) type udpMuxedConnParams struct { - Mux *UDPMuxDefault - AddrPool *sync.Pool - Key string - LocalAddr net.Addr - Logger logging.LeveledLogger + Mux *SingleSocketUDPMux + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger + CandidateID string } // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag @@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error { return err } +func (c *udpMuxedConn) GetCandidateID() string { + return c.params.CandidateID +} + func (c *udpMuxedConn) isClosed() bool { select { case <-c.closedChan: diff --git a/client/iface/udpmux/doc.go b/client/iface/udpmux/doc.go new file mode 100644 index 000000000..27e5e43bc --- /dev/null +++ b/client/iface/udpmux/doc.go @@ -0,0 +1,64 @@ +// Package udpmux provides a custom implementation of a UDP multiplexer +// that allows multiple logical ICE connections to share a single underlying +// UDP socket. This is based on Pion's ICE library, with modifications for +// NetBird's requirements. +// +// # Background +// +// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity +// Establishment) is responsible for discovering candidate network paths +// and maintaining connectivity between peers. Each ICE connection +// normally requires a dedicated UDP socket. 
However, using one socket +// per candidate can be inefficient and difficult to manage. +// +// This package introduces SingleSocketUDPMux, which allows multiple ICE +// candidate connections (muxed connections) to share a single UDP socket. +// It handles demultiplexing of packets based on ICE ufrag values, STUN +// attributes, and candidate IDs. +// +// # Usage +// +// The typical flow is: +// +// 1. Create a UDP socket (net.PacketConn). +// 2. Construct Params with the socket and optional logger/net stack. +// 3. Call NewSingleSocketUDPMux(params). +// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID) +// to obtain a logical PacketConn. +// 5. Use the returned PacketConn just like a normal UDP connection. +// +// # STUN Message Routing Logic +// +// When a STUN packet arrives, the mux decides which connection should +// receive it using this routing logic: +// +// Primary Routing: Candidate Pair ID +// - Extract the candidate pair ID from the STUN message using +// ice.CandidatePairIDFromSTUN(msg) +// - The target candidate is the locally generated candidate that +// corresponds to the connection that should handle this STUN message +// - If found, use the target candidate ID to lookup the specific +// connection in candidateConnMap +// - Route the message directly to that connection +// +// Fallback Routing: Broadcasting +// When candidate pair ID is not available or lookup fails: +// - Collect connections from addressMap based on source address +// - Find connection using username attribute (ufrag) from STUN message +// - Remove duplicate connections from the list +// - Send the STUN message to all collected connections +// +// # Peer Reflexive Candidate Discovery +// +// When a remote peer sends a STUN message from an unknown source address +// (from a candidate that has not been exchanged via signal), the ICE +// library will: +// - Generate a new peer reflexive candidate for this source address +// - Extract or assign a candidate ID 
based on the STUN message attributes +// - Create a mapping between the new peer reflexive candidate ID and +// the appropriate local connection +// +// This discovery mechanism ensures that STUN messages from newly discovered +// peer reflexive candidates can be properly routed to the correct local +// connection without requiring fallback broadcasting. +package udpmux diff --git a/client/iface/bind/udp_mux.go b/client/iface/udpmux/mux.go similarity index 65% rename from client/iface/bind/udp_mux.go rename to client/iface/udpmux/mux.go index 29e5d7937..319724926 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/udpmux/mux.go @@ -1,4 +1,4 @@ -package bind +package udpmux import ( "fmt" @@ -8,9 +8,9 @@ import ( "strings" "sync" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" "github.com/pion/logging" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" @@ -22,9 +22,9 @@ import ( const receiveMTU = 8192 -// UDPMuxDefault is an implementation of the interface -type UDPMuxDefault struct { - params UDPMuxParams +// SingleSocketUDPMux is an implementation of the interface +type SingleSocketUDPMux struct { + params Params closedChan chan struct{} closeOnce sync.Once @@ -32,6 +32,9 @@ type UDPMuxDefault struct { // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn + // candidateConnMap maps local candidate IDs to their corresponding connection. + candidateConnMap map[string]*udpMuxedConn + addressMapMu sync.RWMutex addressMap map[string][]*udpMuxedConn @@ -46,8 +49,8 @@ type UDPMuxDefault struct { const maxAddrSize = 512 -// UDPMuxParams are parameters for UDPMux. -type UDPMuxParams struct { +// Params are parameters for UDPMux. 
+type Params struct { Logger logging.LeveledLogger UDPConn net.PacketConn @@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool { return true } -// NewUDPMuxDefault creates an implementation of UDPMux -func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { +// NewSingleSocketUDPMux creates an implementation of UDPMux +func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { if params.Logger == nil { params.Logger = getLogger() } - mux := &UDPMuxDefault{ - addressMap: map[string][]*udpMuxedConn{}, - params: params, - connsIPv4: make(map[string]*udpMuxedConn), - connsIPv6: make(map[string]*udpMuxedConn), - closedChan: make(chan struct{}, 1), + mux := &SingleSocketUDPMux{ + addressMap: map[string][]*udpMuxedConn{}, + params: params, + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), + candidateConnMap: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address @@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return mux } -func (m *UDPMuxDefault) updateLocalAddresses() { +func (m *SingleSocketUDPMux) updateLocalAddresses() { var localAddrsForUnspecified []net.Addr if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) } else if ok && addr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection - // with UDPMuxDefault, so print a warn log and create a local address list for mux. - m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + // with SingleSocketUDPMux, so print a warn log and create a local address list for mux. 
+ m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { @@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() { m.mu.Unlock() } -// LocalAddr returns the listening address of this UDPMuxDefault -func (m *UDPMuxDefault) LocalAddr() net.Addr { +// LocalAddr returns the listening address of this SingleSocketUDPMux +func (m *SingleSocketUDPMux) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetListenAddresses returns the list of addresses that this mux is listening on -func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { +func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { m.updateLocalAddresses() m.mu.Lock() @@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { // GetConn returns a PacketConn given the connection's ufrag and network address // creates the connection if an existing one can't be found -func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { +func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) { // don't check addr for mux using unspecified address m.mu.Lock() lenLocalAddrs := len(m.localAddrsForUnspecified) @@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er return conn, nil } - c := m.createMuxedConn(ufrag) + c := m.createMuxedConn(ufrag, candidateID) go func() { <-c.CloseChannel() m.RemoveConnByUfrag(ufrag) }() + m.candidateConnMap[candidateID] = c + if isIPv6 { m.connsIPv6[ufrag] = c } else { @@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er } // RemoveConnByUfrag stops and removes the muxed packet connection -func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { +func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock 
section small to avoid deadlock with conn lock @@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } m.mu.Unlock() @@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { } // IsClosed returns true if the mux had been closed -func (m *UDPMuxDefault) IsClosed() bool { +func (m *SingleSocketUDPMux) IsClosed() bool { select { case <-m.closedChan: return true @@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool { } // Close the mux, no further connections could be created -func (m *UDPMuxDefault) Close() error { +func (m *SingleSocketUDPMux) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() @@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error { return err } -func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { +func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } -func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { +func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } @@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } -func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { +func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ - Mux: m, - Key: key, - AddrPool: m.pool, - LocalAddr: m.LocalAddr(), - Logger: m.params.Logger, + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: 
m.LocalAddr(), + Logger: m.params.Logger, + CandidateID: candidateID, }) return c } // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library -func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { - +func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { remoteAddr, ok := addr.(*net.UDPAddr) if !ok { return fmt.Errorf("underlying PacketConn did not return a UDPAddr") } - // If we have already seen this address dispatch to the appropriate destination - // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one - // muxed connection - one for the SRFLX candidate and the other one for the HOST one. - // We will then forward STUN packets to each of these connections. - m.addressMapMu.RLock() + // Try to route to specific candidate connection first + if conn := m.findCandidateConnection(msg); conn != nil { + return conn.writePacket(msg.Raw, remoteAddr) + } + + // Fallback: route to all possible connections + return m.forwardToAllConnections(msg, addr, remoteAddr) +} + +// findCandidateConnection attempts to find the specific connection for a STUN message +func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn { + candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg) + if err != nil { + return nil + } else if !ok { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()] + if !exists { + return nil + } + return conn +} + +// forwardToAllConnections forwards STUN message to all relevant connections +func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error { var destinationConnList []*udpMuxedConn + + // Add connections from address map + m.addressMapMu.RLock() if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = 
append(destinationConnList, storedConns...) } m.addressMapMu.RUnlock() - var isIPv6 bool - if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { - isIPv6 = true + if conn, ok := m.findConnectionByUsername(msg, addr); ok { + // If we have already seen this address dispatch to the appropriate destination + // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one + // muxed connection - one for the SRFLX candidate and the other one for the HOST one. + // We will then forward STUN packets to each of these connections. + if !m.connectionExists(conn, destinationConnList) { + destinationConnList = append(destinationConnList, conn) + } } - // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. - // However, we can take a username attribute from the STUN message which contains ufrag. - // We can use ufrag to identify the destination conn to route packet to. - attr, stunAttrErr := msg.Get(stun.AttrUsername) - if stunAttrErr == nil { - ufrag := strings.Split(string(attr), ":")[0] - - m.mu.Lock() - destinationConn := m.connsIPv4[ufrag] - if isIPv6 { - destinationConn = m.connsIPv6[ufrag] - } - - if destinationConn != nil { - exists := false - for _, conn := range destinationConnList { - if conn.params.Key == destinationConn.params.Key { - exists = true - break - } - } - if !exists { - destinationConnList = append(destinationConnList, destinationConn) - } - } - m.mu.Unlock() - } - - // Forward STUN packets to each destination connections even thought the STUN packet might not belong there. - // It will be discarded by the further ICE candidate logic if so. 
+ // Forward to all found connections for _, conn := range destinationConnList { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { log.Errorf("could not write packet: %v", err) } } - return nil } -func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { +// findConnectionByUsername finds connection using username attribute from STUN message +func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) { + attr, err := msg.Get(stun.AttrUsername) + if err != nil { + return nil, false + } + + ufrag := strings.Split(string(attr), ":")[0] + isIPv6 := isIPv6Address(addr) + + m.mu.Lock() + defer m.mu.Unlock() + + return m.getConn(ufrag, isIPv6) +} + +// connectionExists checks if a connection already exists in the list +func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool { + for _, conn := range conns { + if conn.params.Key == target.params.Key { + return true + } + } + return false +} + +func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { @@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o return } +func isIPv6Address(addr net.Addr) bool { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + return udpAddr.IP.To4() == nil + } + return false +} + type bufferHolder struct { buf []byte } diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/udpmux/mux_generic.go similarity index 76% rename from client/iface/bind/udp_mux_generic.go rename to client/iface/udpmux/mux_generic.go index 63f786d2b..29fc2d834 100644 --- a/client/iface/bind/udp_mux_generic.go +++ b/client/iface/udpmux/mux_generic.go @@ -1,12 +1,12 @@ //go:build !ios -package bind +package udpmux import ( - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) -func (m *UDPMuxDefault) 
notifyAddressRemoval(addr string) { +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { conn.RemoveAddress(addr) diff --git a/client/iface/udpmux/mux_ios.go b/client/iface/udpmux/mux_ios.go new file mode 100644 index 000000000..4cf211d8f --- /dev/null +++ b/client/iface/udpmux/mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package udpmux + +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/udpmux/universal.go similarity index 96% rename from client/iface/bind/udp_mux_universal.go rename to client/iface/udpmux/universal.go index b06da6712..43bfedaaa 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/udpmux/universal.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. @@ -15,7 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/pion/logging" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" "github.com/netbirdio/netbird/client/iface/bufsize" @@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. 
type UniversalUDPMuxDefault struct { - *UDPMuxDefault + *SingleSocketUDPMux params UniversalUDPMuxParams // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents @@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - udpMuxParams := UDPMuxParams{ + udpMuxParams := Params{ Logger: params.Logger, UDPConn: m.params.UDPConn, Net: m.params.Net, } - m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams) return m } @@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. -func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { - return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) { + return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID) } // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. @@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A } return nil } - return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) + return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr) } // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. 
diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index fcdc0189d..b899f1694 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/internal/connect.go b/client/internal/connect.go index a3cc7be1d..2bfa263fc 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -34,7 +34,7 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) @@ -275,7 +275,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engineMutex.Unlock() - if err := c.engine.Start(); err != nil { + if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } @@ -284,10 +284,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan state.Set(StatusConnected) if runningChan != nil { - select { - case runningChan <- struct{}{}: - default: - } + close(runningChan) + runningChan = nil } <-engineCtx.Done() diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go new file mode 100644 index 000000000..cb651f1e5 --- /dev/null +++ b/client/internal/dns/config/domains.go @@ -0,0 +1,201 @@ +package config + +import ( + "errors" + "fmt" + "net" + "net/netip" + "net/url" + "strings" + + 
log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/shared/management/domain" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +var ( + ErrEmptyURL = errors.New("empty URL") + ErrEmptyHost = errors.New("empty host") + ErrIPNotAllowed = errors.New("IP address not allowed") +) + +// ServerDomains represents the management server domains extracted from NetBird configuration +type ServerDomains struct { + Signal domain.Domain + Relay []domain.Domain + Flow domain.Domain + Stuns []domain.Domain + Turns []domain.Domain +} + +// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration +func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { + if config == nil { + return ServerDomains{} + } + + domains := ServerDomains{} + + domains.Signal = extractSignalDomain(config) + domains.Relay = extractRelayDomains(config) + domains.Flow = extractFlowDomain(config) + domains.Stuns = extractStunDomains(config) + domains.Turns = extractTurnDomains(config) + + return domains +} + +// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses +func ExtractValidDomain(rawURL string) (domain.Domain, error) { + if rawURL == "" { + return "", ErrEmptyURL + } + + parsedURL, err := url.Parse(rawURL) + if err == nil { + if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" { + return domain, err + } + } + + return extractFromRawString(rawURL) +} + +// extractFromParsedURL handles domain extraction from successfully parsed URLs +func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) { + if parsedURL.Hostname() != "" { + return extractDomainFromHost(parsedURL.Hostname()) + } + + if parsedURL.Opaque == "" || parsedURL.Scheme == "" { + return "", nil + } + + // Handle URLs with opaque content (e.g., stun:host:port) + if strings.Contains(parsedURL.Scheme, ".") { + // This is likely "domain.com:port" being parsed as scheme:opaque + reconstructed := 
parsedURL.Scheme + ":" + parsedURL.Opaque + if host, _, err := net.SplitHostPort(reconstructed); err == nil { + return extractDomainFromHost(host) + } + return extractDomainFromHost(parsedURL.Scheme) + } + + // Valid scheme with opaque content (e.g., stun:host:port) + host := parsedURL.Opaque + if queryIndex := strings.Index(host, "?"); queryIndex > 0 { + host = host[:queryIndex] + } + + if hostOnly, _, err := net.SplitHostPort(host); err == nil { + return extractDomainFromHost(hostOnly) + } + + return extractDomainFromHost(host) +} + +// extractFromRawString handles domain extraction when URL parsing fails or returns no results +func extractFromRawString(rawURL string) (domain.Domain, error) { + if host, _, err := net.SplitHostPort(rawURL); err == nil { + return extractDomainFromHost(host) + } + + return extractDomainFromHost(rawURL) +} + +// extractDomainFromHost extracts domain from a host string, filtering out IP addresses +func extractDomainFromHost(host string) (domain.Domain, error) { + if host == "" { + return "", ErrEmptyHost + } + + if _, err := netip.ParseAddr(host); err == nil { + return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, host) + } + + d, err := domain.FromString(host) + if err != nil { + return "", fmt.Errorf("invalid domain: %v", err) + } + + return d, nil +} + +// extractSingleDomain extracts a single domain from a URL with error logging +func extractSingleDomain(url, serviceType string) domain.Domain { + if url == "" { + return "" + } + + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + return "" + } + + return d +} + +// extractMultipleDomains extracts multiple domains from URLs with error logging +func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { + var domains []domain.Domain + for _, url := range urls { + if url == "" { + continue + } + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + continue + } + 
domains = append(domains, d) + } + return domains +} + +// extractSignalDomain extracts the signal domain from NetBird configuration. +func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Signal != nil { + return extractSingleDomain(config.Signal.Uri, "signal") + } + return "" +} + +// extractRelayDomains extracts relay server domains from NetBird configuration. +func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Relay != nil { + return extractMultipleDomains(config.Relay.Urls, "relay") + } + return nil +} + +// extractFlowDomain extracts the traffic flow domain from NetBird configuration. +func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Flow != nil { + return extractSingleDomain(config.Flow.Url, "flow") + } + return "" +} + +// extractStunDomains extracts STUN server domains from NetBird configuration. +func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, stun := range config.Stuns { + if stun != nil && stun.Uri != "" { + urls = append(urls, stun.Uri) + } + } + return extractMultipleDomains(urls, "STUN") +} + +// extractTurnDomains extracts TURN server domains from NetBird configuration. 
+func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, turn := range config.Turns { + if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { + urls = append(urls, turn.HostConfig.Uri) + } + } + return extractMultipleDomains(urls, "TURN") +} diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go new file mode 100644 index 000000000..5eae3a541 --- /dev/null +++ b/client/internal/dns/config/domains_test.go @@ -0,0 +1,213 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractValidDomain(t *testing.T) { + tests := []struct { + name string + url string + expected string + expectError bool + }{ + { + name: "HTTPS URL with port", + url: "https://api.netbird.io:443", + expected: "api.netbird.io", + }, + { + name: "HTTP URL without port", + url: "http://signal.example.com", + expected: "signal.example.com", + }, + { + name: "Host with port (no scheme)", + url: "signal.netbird.io:443", + expected: "signal.netbird.io", + }, + { + name: "STUN URL", + url: "stun:stun.netbird.io:443", + expected: "stun.netbird.io", + }, + { + name: "STUN URL with different port", + url: "stun:stun.netbird.io:5555", + expected: "stun.netbird.io", + }, + { + name: "TURNS URL with query params", + url: "turns:turn.netbird.io:443?transport=tcp", + expected: "turn.netbird.io", + }, + { + name: "TURN URL", + url: "turn:turn.example.com:3478", + expected: "turn.example.com", + }, + { + name: "REL URL", + url: "rel://relay.example.com:443", + expected: "relay.example.com", + }, + { + name: "RELS URL", + url: "rels://relay.netbird.io:443", + expected: "relay.netbird.io", + }, + { + name: "Raw hostname", + url: "example.org", + expected: "example.org", + }, + { + name: "IP address should be rejected", + url: "192.168.1.1", + expectError: true, + }, + { + name: "IP address with port should be rejected", + url: "192.168.1.1:443", + expectError: 
true, + }, + { + name: "IPv6 address should be rejected", + url: "2001:db8::1", + expectError: true, + }, + { + name: "HTTP URL with IPv4 should be rejected", + url: "http://192.168.1.1:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv4 should be rejected", + url: "https://10.0.0.1:443", + expectError: true, + }, + { + name: "STUN URL with IPv4 should be rejected", + url: "stun:192.168.1.1:3478", + expectError: true, + }, + { + name: "TURN URL with IPv4 should be rejected", + url: "turn:10.0.0.1:3478", + expectError: true, + }, + { + name: "TURNS URL with IPv4 should be rejected", + url: "turns:172.16.0.1:5349", + expectError: true, + }, + { + name: "HTTP URL with IPv6 should be rejected", + url: "http://[2001:db8::1]:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv6 should be rejected", + url: "https://[::1]:443", + expectError: true, + }, + { + name: "STUN URL with IPv6 should be rejected", + url: "stun:[2001:db8::1]:3478", + expectError: true, + }, + { + name: "IPv6 with port should be rejected", + url: "[2001:db8::1]:443", + expectError: true, + }, + { + name: "Localhost IPv4 should be rejected", + url: "127.0.0.1:8080", + expectError: true, + }, + { + name: "Localhost IPv6 should be rejected", + url: "[::1]:443", + expectError: true, + }, + { + name: "REL URL with IPv4 should be rejected", + url: "rel://192.168.1.1:443", + expectError: true, + }, + { + name: "RELS URL with IPv4 should be rejected", + url: "rels://10.0.0.1:443", + expectError: true, + }, + { + name: "Empty URL", + url: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ExtractValidDomain(tt.url) + + if tt.expectError { + assert.Error(t, err, "Expected error for URL: %s", tt.url) + } else { + assert.NoError(t, err, "Unexpected error for URL: %s", tt.url) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url) + } + }) + } +} + +func TestExtractDomainFromHost(t 
*testing.T) { + tests := []struct { + name string + host string + expected string + expectError bool + }{ + { + name: "Valid domain", + host: "example.com", + expected: "example.com", + }, + { + name: "Subdomain", + host: "api.example.com", + expected: "api.example.com", + }, + { + name: "IPv4 address", + host: "192.168.1.1", + expectError: true, + }, + { + name: "IPv6 address", + host: "2001:db8::1", + expectError: true, + }, + { + name: "Empty host", + host: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := extractDomainFromHost(tt.host) + + if tt.expectError { + assert.Error(t, err, "Expected error for host: %s", tt.host) + } else { + assert.NoError(t, err, "Unexpected error for host: %s", tt.host) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host) + } + }) + } +} diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 439bcbb3c..2e54bffd9 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,11 +11,12 @@ import ( ) const ( - PriorityLocal = 100 - PriorityDNSRoute = 75 - PriorityUpstream = 50 - PriorityDefault = 1 - PriorityFallback = -100 + PriorityMgmtCache = 150 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 + PriorityFallback = -100 ) type SubdomainMatcher interface { @@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants to continue, try next handler if chainWriter.shouldContinue { - log.Tracef("handler requested continue to next handler for domain=%s", qname) + // Only log continue for non-management cache handlers to reduce noise + if entry.Priority != PriorityMgmtCache { + log.Tracef("handler requested continue to next handler for domain=%s", qname) + } continue } return diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 
852dfef48..b06ba73ab 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -166,9 +166,10 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { - err := s.recordSystemDNSSettings(true) - log.Errorf("Unable to get system DNS configuration") - return err + if err := s.recordSystemDNSSettings(true); err != nil { + log.Errorf("Unable to get system DNS configuration") + return fmt.Errorf("recordSystemDNSSettings(): %w", err) + } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index b776fbbe3..bac7875ec 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool { // String returns a string representation of the local resolver func (d *Resolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.records)) + return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) } func (d *Resolver) Stop() {} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go new file mode 100644 index 000000000..290395473 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt.go @@ -0,0 +1,360 @@ +package mgmt + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "sync" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/shared/management/domain" +) + +const dnsTimeout = 5 * time.Second + +// Resolver caches critical NetBird infrastructure domains +type Resolver struct { + records map[dns.Question][]dns.RR + mgmtDomain *domain.Domain + 
serverDomains *dnsconfig.ServerDomains + mutex sync.RWMutex +} + +// NewResolver creates a new management domains cache resolver. +func NewResolver() *Resolver { + return &Resolver{ + records: make(map[dns.Question][]dns.RR), + } +} + +// String returns a string representation of the resolver. +func (m *Resolver) String() string { + return "MgmtCacheResolver" +} + +// ServeDNS implements dns.Handler interface. +func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + m.continueToNext(w, r) + return + } + + question := r.Question[0] + question.Name = strings.ToLower(dns.Fqdn(question.Name)) + + if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { + m.continueToNext(w, r) + return + } + + m.mutex.RLock() + records, found := m.records[question] + m.mutex.RUnlock() + + if !found { + m.continueToNext(w, r) + return + } + + resp := &dns.Msg{} + resp.SetReply(r) + resp.Authoritative = false + resp.RecursionAvailable = true + + resp.Answer = append(resp.Answer, records...) + + log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write response: %v", err) + } +} + +// MatchSubdomains returns false since this resolver only handles exact domain matches +// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. +func (m *Resolver) MatchSubdomains() bool { + return false +} + +// continueToNext signals the handler chain to continue to the next handler. +func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write continue signal: %v", err) + } +} + +// AddDomain manually adds a domain to cache by resolving it. 
+func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + ctx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + if err != nil { + return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) + } + + var aRecords, aaaaRecords []dns.RR + for _, ip := range ips { + if ip.Is4() { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: ip.AsSlice(), + } + aRecords = append(aRecords, rr) + } else if ip.Is6() { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: ip.AsSlice(), + } + aaaaRecords = append(aaaaRecords, rr) + } + } + + m.mutex.Lock() + + if len(aRecords) > 0 { + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + m.records[aQuestion] = aRecords + } + + if len(aaaaRecords) > 0 { + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + m.records[aaaaQuestion] = aaaaRecords + } + + m.mutex.Unlock() + + log.Debugf("added domain=%s with %d A records and %d AAAA records", + d.SafeString(), len(aRecords), len(aaaaRecords)) + + return nil +} + +// PopulateFromConfig extracts and caches domains from the client configuration. +func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { + if mgmtURL == nil { + return nil + } + + d, err := dnsconfig.ExtractValidDomain(mgmtURL.String()) + if err != nil { + return fmt.Errorf("extract domain from URL: %w", err) + } + + m.mutex.Lock() + m.mgmtDomain = &d + m.mutex.Unlock() + + if err := m.AddDomain(ctx, d); err != nil { + return fmt.Errorf("add domain: %w", err) + } + + return nil +} + +// RemoveDomain removes a domain from the cache. 
+func (m *Resolver) RemoveDomain(d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + m.mutex.Lock() + defer m.mutex.Unlock() + + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + delete(m.records, aQuestion) + + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + delete(m.records, aaaaQuestion) + + log.Debugf("removed domain=%s from cache", d.SafeString()) + return nil +} + +// GetCachedDomains returns a list of all cached domains. +func (m *Resolver) GetCachedDomains() domain.List { + m.mutex.RLock() + defer m.mutex.RUnlock() + + domainSet := make(map[domain.Domain]struct{}) + for question := range m.records { + domainName := strings.TrimSuffix(question.Name, ".") + domainSet[domain.Domain(domainName)] = struct{}{} + } + + domains := make(domain.List, 0, len(domainSet)) + for d := range domainSet { + domains = append(domains, d) + } + + return domains +} + +// UpdateFromServerDomains updates the cache with server domains from network configuration. +// It merges new domains with existing ones, replacing entire domain types when updated. +// Empty updates are ignored to prevent clearing infrastructure domains during partial updates. 
+func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) { + newDomains := m.extractDomainsFromServerDomains(serverDomains) + var removedDomains domain.List + + if len(newDomains) > 0 { + m.mutex.Lock() + if m.serverDomains == nil { + m.serverDomains = &dnsconfig.ServerDomains{} + } + updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains) + m.serverDomains = &updatedServerDomains + m.mutex.Unlock() + + allDomains := m.extractDomainsFromServerDomains(updatedServerDomains) + currentDomains := m.GetCachedDomains() + removedDomains = m.removeStaleDomains(currentDomains, allDomains) + } + + m.addNewDomains(ctx, newDomains) + + return removedDomains, nil +} + +// removeStaleDomains removes cached domains not present in the target domain list. +// Management domains are preserved and never removed during server domain updates. +func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List { + var removedDomains domain.List + + for _, currentDomain := range currentDomains { + if m.isDomainInList(currentDomain, newDomains) { + continue + } + + if m.isManagementDomain(currentDomain) { + continue + } + + removedDomains = append(removedDomains, currentDomain) + if err := m.RemoveDomain(currentDomain); err != nil { + log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) + } + } + + return removedDomains +} + +// mergeServerDomains merges new server domains with existing ones. +// When a domain type is provided in the new domains, it completely replaces that type. 
+func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains { + merged := existing + + if incoming.Signal != "" { + merged.Signal = incoming.Signal + } + if len(incoming.Relay) > 0 { + merged.Relay = incoming.Relay + } + if incoming.Flow != "" { + merged.Flow = incoming.Flow + } + if len(incoming.Stuns) > 0 { + merged.Stuns = incoming.Stuns + } + if len(incoming.Turns) > 0 { + merged.Turns = incoming.Turns + } + + return merged +} + +// isDomainInList checks if domain exists in the list +func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool { + for _, d := range list { + if domain.SafeString() == d.SafeString() { + return true + } + } + return false +} + +// isManagementDomain checks if domain is the protected management domain +func (m *Resolver) isManagementDomain(domain domain.Domain) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return m.mgmtDomain != nil && domain == *m.mgmtDomain +} + +// addNewDomains resolves and caches all domains from the update +func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) { + for _, newDomain := range newDomains { + if err := m.AddDomain(ctx, newDomain); err != nil { + log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) + } else { + log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) + } + } +} + +func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List { + var domains domain.List + + if serverDomains.Signal != "" { + domains = append(domains, serverDomains.Signal) + } + + for _, relay := range serverDomains.Relay { + if relay != "" { + domains = append(domains, relay) + } + } + + if serverDomains.Flow != "" { + domains = append(domains, serverDomains.Flow) + } + + for _, stun := range serverDomains.Stuns { + if stun != "" { + domains = append(domains, stun) + } + } + + for _, turn := range serverDomains.Turns { + if turn != "" { 
+ domains = append(domains, turn) + } + } + + return domains +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go new file mode 100644 index 000000000..99d289871 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -0,0 +1,416 @@ +package mgmt + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +func TestResolver_NewResolver(t *testing.T) { + resolver := NewResolver() + + assert.NotNil(t, resolver) + assert.NotNil(t, resolver.records) + assert.False(t, resolver.MatchSubdomains()) +} + +func TestResolver_ExtractDomainFromURL(t *testing.T) { + tests := []struct { + name string + urlStr string + expectedDom string + expectError bool + }{ + { + name: "HTTPS URL with port", + urlStr: "https://api.netbird.io:443", + expectedDom: "api.netbird.io", + expectError: false, + }, + { + name: "HTTP URL without port", + urlStr: "http://signal.example.com", + expectedDom: "signal.example.com", + expectError: false, + }, + { + name: "URL with path", + urlStr: "https://relay.netbird.io/status", + expectedDom: "relay.netbird.io", + expectError: false, + }, + { + name: "Invalid URL", + urlStr: "not-a-valid-url", + expectedDom: "not-a-valid-url", + expectError: false, + }, + { + name: "Empty URL", + urlStr: "", + expectedDom: "", + expectError: true, + }, + { + name: "STUN URL", + urlStr: "stun:stun.example.com:3478", + expectedDom: "stun.example.com", + expectError: false, + }, + { + name: "TURN URL", + urlStr: "turn:turn.example.com:3478", + expectedDom: "turn.example.com", + expectError: false, + }, + { + name: "REL URL", + urlStr: "rel://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + { + name: "RELS URL", + urlStr: 
"rels://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var parsedURL *url.URL + var err error + + if tt.urlStr != "" { + parsedURL, err = url.Parse(tt.urlStr) + if err != nil && !tt.expectError { + t.Fatalf("Failed to parse URL: %v", err) + } + } + + domain, err := extractDomainFromURL(parsedURL) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedDom, domain.SafeString()) + } + }) + } +} + +func TestResolver_PopulateFromConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := NewResolver() + + // Test with IP address - should return error since IP addresses are rejected + mgmtURL, _ := url.Parse("https://127.0.0.1") + + err := resolver.PopulateFromConfig(ctx, mgmtURL) + assert.Error(t, err) + assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed) + + // No domains should be cached when using IP addresses + domains := resolver.GetCachedDomains() + assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") +} + +func TestResolver_ServeDNS(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Add a test domain to the cache - use example.org which is reserved for testing + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Test A record query for cached domain + t.Run("Cached domain A record", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + 
assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer") + }) + + // Test uncached domain signals to continue to next handler + t.Run("Uncached domain signals continue to next handler", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("unknown.example.com.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + // Zero flag set to true signals the handler chain to continue to next handler + assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain") + }) + + // Test that subdomains of cached domains are NOT resolved + t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query for a subdomain of our cached domain + req := new(dns.Msg) + req.SetQuestion("sub.example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains") + }) + + // Test case-insensitive matching + t.Run("Case-insensitive domain matching", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query with different casing + req := new(dns.Msg) + 
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case") + }) +} + +func TestResolver_GetCachedDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + cachedDomains := resolver.GetCachedDomains() + + assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain") + assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original") + assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot") +} + +func TestResolver_ManagementDomainProtection(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + mgmtURL, _ := url.Parse("https://example.org") + err := resolver.PopulateFromConfig(ctx, mgmtURL) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + initialDomains := resolver.GetCachedDomains() + if len(initialDomains) == 0 { + t.Skip("Management domain failed to resolve, skipping test") + } + assert.Equal(t, 1, len(initialDomains), "Should have management domain cached") + assert.Equal(t, "example.org", initialDomains[0].SafeString()) + + serverDomains := dnsconfig.ServerDomains{ + Signal: "google.com", + Relay: []domain.Domain{"cloudflare.com"}, + } + + _, err = resolver.UpdateFromServerDomains(ctx, serverDomains) + if err != nil { + t.Logf("Server domains update failed: %v", err) + } + + finalDomains := resolver.GetCachedDomains() + + managementStillCached := false + for _, d := range finalDomains { + 
if d.SafeString() == "example.org" { + managementStillCached = true + break + } + } + assert.True(t, managementStillCached, "Management domain should never be removed") +} + +// extractDomainFromURL extracts a domain from a URL - test helper function +func extractDomainFromURL(u *url.URL) (domain.Domain, error) { + if u == nil { + return "", fmt.Errorf("URL is nil") + } + return dnsconfig.ExtractValidDomain(u.String()) +} + +func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Verify domains were added + cachedDomains := resolver.GetCachedDomains() + assert.Len(t, cachedDomains, 3) + + // Update with empty ServerDomains (simulating partial network map update) + emptyDomains := dnsconfig.ServerDomains{} + removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains) + assert.NoError(t, err) + + // Verify no domains were removed + assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty") + + // Verify all original domains are still cached + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "All original domains should still be cached") +} + +func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := 
resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn) + partialDomains := dnsconfig.ServerDomains{ + Signal: "github.com", + } + removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Should remove only the old signal domain + assert.Len(t, removedDomains, 1, "Should remove only the old signal domain") + assert.Equal(t, "example.org", removedDomains[0].SafeString()) + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "github.com") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.NotContains(t, domainStrings, "example.org") +} + +func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only flow domain - new type, should preserve all existing) + partialDomains := dnsconfig.ServerDomains{ + Flow: "github.com", + } + removedDomains, err 
:= resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "example.org") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.Contains(t, domainStrings, "github.com") +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index d160fa99a..0f89b9016 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,20 +3,23 @@ package dns import ( "fmt" "net/netip" + "net/url" "github.com/miekg/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/shared/management/domain" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func(domain.List, dns.Handler, int) - DeregisterHandlerFunc func(domain.List, int) + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func(domain.List, dns.Handler, int) + DeregisterHandlerFunc func(domain.List, int) + UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error } func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string { // ProbeAvailability mocks implementation of ProbeAvailability from the 
Server interface func (m *MockServer) ProbeAvailability() { } + +func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + if m.UpdateServerConfigFunc != nil { + return m.UpdateServerConfigFunc(domains) + } + return nil +} + +func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { + return nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index cbcf6a256..8cb886203 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/netip" + "net/url" "runtime" "strings" "sync" @@ -15,7 +16,9 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -45,6 +48,8 @@ type Server interface { OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string ProbeAvailability() + UpdateServerConfig(domains dnsconfig.ServerDomains) error + PopulateManagementDomain(mgmtURL *url.URL) error } type nsGroupsByDomain struct { @@ -77,6 +82,8 @@ type DefaultServer struct { handlerChain *HandlerChain extraDomains map[domain.Domain]int + mgmtCacheResolver *mgmt.Resolver + // permanent related properties permanent bool hostsDNSHolder *hostsDNSHolder @@ -104,18 +111,20 @@ type handlerWrapper struct { type registeredHandlerMap map[types.HandlerID]handlerWrapper +// DefaultServerConfig holds configuration parameters for NewDefaultServer +type DefaultServerConfig struct { + WgInterface WGIface + CustomAddress string + StatusRecorder *peer.Status + StateManager *statemanager.Manager + DisableSys bool +} + // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx 
context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, - stateManager *statemanager.Manager, - disableSys bool, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) { var addrPort *netip.AddrPort - if customAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(customAddress) + if config.CustomAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) if err != nil { return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) } @@ -123,13 +132,14 @@ func NewDefaultServer( } var dnsService service - if wgInterface.IsUserspaceBind() { - dnsService = NewServiceViaMemory(wgInterface) + if config.WgInterface.IsUserspaceBind() { + dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(wgInterface, addrPort) + dnsService = newServiceViaListener(config.WgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil + server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) + return server, nil } // NewDefaultServerPermanentUpstream returns a new dns server. 
It optimized for mobile systems @@ -178,20 +188,24 @@ func newDefaultServer( ) *DefaultServer { handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) + + mgmtCacheResolver := mgmt.NewResolver() + defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: local.NewResolver(), - wgInterface: wgInterface, - statusRecorder: statusRecorder, - stateManager: stateManager, - hostsDNSHolder: newHostsDNSHolder(), - hostManager: &noopHostConfigurator{}, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), + wgInterface: wgInterface, + statusRecorder: statusRecorder, + stateManager: stateManager, + hostsDNSHolder: newHostsDNSHolder(), + hostManager: &noopHostConfigurator{}, + mgmtCacheResolver: mgmtCacheResolver, } // register with root zone, handler chain takes care of the routing @@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { - log.Debugf("registering handler %s with priority %d", handler, priority) + log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains) for _, domain := range domains { if domain == "" { @@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { } func (s *DefaultServer) deregisterHandler(domains []string, priority int) { - log.Debugf("deregistering handler %v with priority %d", domains, priority) + log.Debugf("deregistering handler with priority %d for %v", priority, domains) for _, domain := range domains { if domain == "" { @@ -432,6 +446,29 @@ func (s 
*DefaultServer) ProbeAvailability() { wg.Wait() } +func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + s.mux.Lock() + defer s.mux.Unlock() + + if s.mgmtCacheResolver != nil { + removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) + if err != nil { + return fmt.Errorf("update management cache resolver: %w", err) + } + + if len(removedDomains) > 0 { + s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) + } + + newDomains := s.mgmtCacheResolver.GetCachedDomains() + if len(newDomains) > 0 { + s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) + } + } + + return nil +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver if update.ServiceEnable { @@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain { ), ) } + +// PopulateManagementDomain populates the DNS cache with management domain +func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { + if s.mgmtCacheResolver != nil { + return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) + } + return nil +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 068f001d8..11575d500 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatal(err) } @@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, 
"", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: &mocWGIface{}, + CustomAddress: testCase.addrPort, + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 89d637686..6ef0ab526 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type ServiceViaMemory struct { diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 071e3617a..c19e0acb5 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -33,9 +33,11 @@ func SetCurrentMTU(mtu uint16) { } const ( - UpstreamTimeout = 15 * time.Second + UpstreamTimeout = 4 * time.Second + // ClientTimeout is the timeout for the dns.Client. 
+ // Set longer than UpstreamTimeout to ensure context timeout takes precedence + ClientTimeout = 5 * time.Second - failsTillDeact = int32(5) reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second ) @@ -58,9 +60,7 @@ type upstreamResolverBase struct { upstreamServers []netip.AddrPort domain string disabled bool - failsCount atomic.Int32 successCount atomic.Int32 - failsTillDeact int32 mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration @@ -79,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain: domain, upstreamTimeout: UpstreamTimeout, reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.upstreamServers) } // ID returns the unique handler ID @@ -116,58 +115,102 @@ func (u *upstreamResolverBase) Stop() { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { requestID := GenerateRequestID() logger := log.WithField("request_id", requestID) - var err error - defer func() { - u.checkUpstreamFails(err) - }() logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + u.prepareRequest(r) + + if u.ctx.Err() != nil { + logger.Tracef("%s has been stopped", u) + return + } + + if u.tryUpstreamServers(w, r, logger) { + return + } + + u.writeErrorResponse(w, r, logger) +} + +func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } +} - select { - case <-u.ctx.Done(): - logger.Tracef("%s has been stopped", u) - return - default: +func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { + timeout := u.upstreamTimeout + if 
len(u.upstreamServers) > 1 { + maxTotal := 5 * time.Second + minPerUpstream := 2 * time.Second + scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) + if scaledTimeout > minPerUpstream { + timeout = scaledTimeout + } else { + timeout = minPerUpstream + } } for _, upstream := range u.upstreamServers { - var rm *dns.Msg - var t time.Duration - - func() { - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - defer cancel() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) - }() - - if err != nil { - if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) - continue - } - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) - continue + if u.queryUpstream(w, r, upstream, timeout, logger) { + return true } + } + return false +} - if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - continue - } +func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool { + var rm *dns.Msg + var t time.Duration + var err error - u.successCount.Add(1) - logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + var startTime time.Time + func() { + ctx, cancel := context.WithTimeout(u.ctx, timeout) + defer cancel() + startTime = time.Now() + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) + }() - if err = w.WriteMsg(rm); err != nil { - logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) - } - // count the fails only if they happen sequentially - u.failsCount.Store(0) + if err != nil { + u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) + return false + } + + if rm == nil || 
!rm.Response { + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + return false + } + + return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) +} + +func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { + if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) + return + } + + elapsed := time.Since(startTime) + timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) + if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { + timeoutMsg += " " + peerInfo + } + timeoutMsg += fmt.Sprintf(" - error: %v", err) + logger.Warn(timeoutMsg) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { + u.successCount.Add(1) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain) + + if err := w.WriteMsg(rm); err != nil { + logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) + } + return true +} + +func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { + logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) + m := new(dns.Msg) @@ -177,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -// checkUpstreamFails counts fails and disables or enables upstream resolving -// -// If fails count is greater that failsTillDeact, upstream resolving -// will be disabled for reactivatePeriod, after that time period fails counter -// will be reset and upstream will be 
reactivated. -func (u *upstreamResolverBase) checkUpstreamFails(err error) { - u.mutex.Lock() - defer u.mutex.Unlock() - - if u.failsCount.Load() < u.failsTillDeact || u.disabled { - return - } - - select { - case <-u.ctx.Done(): - return - default: - } - - u.disable(err) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (fail count exceeded)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - // TODO add domain meta - ) -} - // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability() { @@ -224,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() { default: } - // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact - if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { + // avoid probe if upstreams could resolve at least one query + if u.successCount.Load() > 0 { return } @@ -312,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() { } log.Infof("upstreams %s are responsive again. 
Adding them back to system", u.upstreamServersString()) - u.failsCount.Store(0) u.successCount.Add(1) u.reactivate() u.disabled = false @@ -416,3 +423,80 @@ func GenerateRequestID() string { } return hex.EncodeToString(bytes) } + +// FormatPeerStatus formats peer connection status information for debugging DNS timeouts +func FormatPeerStatus(peerState *peer.State) string { + isConnected := peerState.ConnStatus == peer.StatusConnected + hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() && + time.Since(peerState.LastWireguardHandshake) < 3*time.Minute + + statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP) + + switch { + case !isConnected: + statusInfo += " DISCONNECTED" + case !hasRecentHandshake: + statusInfo += " NO_RECENT_HANDSHAKE" + default: + statusInfo += " connected" + } + + if !peerState.LastWireguardHandshake.IsZero() { + timeSinceHandshake := time.Since(peerState.LastWireguardHandshake) + statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second)) + } else { + statusInfo += " no_handshake" + } + + if peerState.Relayed { + statusInfo += " via_relay" + } + + if peerState.Latency > 0 { + statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency) + } + + return statusInfo +} + +// findPeerForIP finds which peer handles the given IP address +func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { + if statusRecorder == nil { + return nil + } + + fullStatus := statusRecorder.GetFullStatus() + var bestMatch *peer.State + var bestPrefixLen int + + for _, peerState := range fullStatus.Peers { + routes := peerState.GetRoutes() + for route := range routes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + continue + } + + if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen { + peerStateCopy := peerState + bestMatch = &peerStateCopy + bestPrefixLen = prefix.Bits() + } + } + } + + return bestMatch +} + +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream 
netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ddbf84ae4..def281f28 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type upstreamResolver struct { @@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns } func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - upstreamExchangeClient := &dns.Client{} + upstreamExchangeClient := &dns.Client{ + Timeout: ClientTimeout, + } return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } @@ -72,10 +74,11 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri } upstreamExchangeClient := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: timeout, } - return upstreamExchangeClient.Exchange(r, upstream) + return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 317588a27..434e5880b 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -34,7 +34,10 @@ func newUpstreamResolver( } func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) + client 
:= &dns.Client{ + Timeout: ClientTimeout, + } + return ExchangeWithFallback(ctx, client, r, upstream) } func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 96b8bbb0f..eadcdd117 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -47,7 +47,9 @@ func newUpstreamResolver( } func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - client := &dns.Client{} + client := &dns.Client{ + Timeout: ClientTimeout, + } upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) @@ -110,7 +112,8 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura }, } client := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: dialTimeout, } return client, nil } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 51d870e2a..e1573e75e 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) } func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { + mockClient := &mockUpstreamResolver{ + err: dns.ErrTime, + r: new(dns.Msg), + rtt: time.Millisecond, + } + resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: &mockUpstreamResolver{ - err: nil, - r: new(dns.Msg), - rtt: time.Millisecond, - }, + ctx: context.TODO(), + upstreamClient: mockClient, upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, + reactivatePeriod: time.Microsecond * 100, } addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection 
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - resolver.failsTillDeact = 0 - resolver.reactivatePeriod = time.Microsecond * 100 - - responseWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { return nil }, - } failed := false resolver.deactivate = func(error) { failed = true + // After deactivation, make the mock client work again + mockClient.err = nil } reactivated := false @@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) + resolver.ProbeAvailability() if !failed { t.Errorf("expected that resolving was deactivated") @@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { return } - if resolver.failsCount.Load() != 0 { - t.Errorf("fails count after reactivation should be 0") - return - } - if resolver.disabled { t.Errorf("should be enabled") } diff --git a/client/internal/engine.go b/client/internal/engine.go index 10f709dfa..61d109219 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "net/netip" + "net/url" "os" "runtime" "slices" @@ -16,8 +17,8 @@ import ( "time" "github.com/hashicorp/go-multierror" - "github.com/pion/ice/v3" - "github.com/pion/stun/v2" + "github.com/pion/ice/v4" + "github.com/pion/stun/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -27,10 +28,11 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" 
"github.com/netbirdio/netbird/client/internal/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" @@ -165,7 +167,7 @@ type Engine struct { wgInterface WGIface - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -196,6 +198,10 @@ type Engine struct { latestSyncResponse *mgmProto.SyncResponse connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager + + // WireGuard interface monitor + wgIfaceMonitor *WGIfaceMonitor + wgIfaceMonitorWg sync.WaitGroup } // Peer is an instance of the Connection Peer @@ -344,13 +350,16 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } + // Stop WireGuard interface monitor and wait for it to exit + e.wgIfaceMonitorWg.Wait() + return nil } // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. 
// However, they will be established once an event with a list of peers to connect to will be received from Management Service -func (e *Engine) Start() error { +func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -406,6 +415,11 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer + // Populate DNS cache with NetbirdConfig and management URL for early resolution + if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache: %v", err) + } + e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ Context: e.ctx, PublicKey: e.config.WgPrivateKey.PublicKey().String(), @@ -459,7 +473,7 @@ func (e *Engine) Start() error { StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), } @@ -475,6 +489,22 @@ func (e *Engine) Start() error { // starting network monitor at the very last to avoid disruptions e.startNetworkMonitor() + + // monitor WireGuard interface lifecycle and restart engine on changes + e.wgIfaceMonitor = NewWGIfaceMonitor() + e.wgIfaceMonitorWg.Add(1) + + go func() { + defer e.wgIfaceMonitorWg.Done() + + if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { + log.Infof("WireGuard interface monitor: %s, restarting engine", err) + e.restartEngine() + } else if err != nil { + log.Warnf("WireGuard interface monitor: %s", err) + } + }() + return nil } @@ -666,6 +696,30 @@ func (e *Engine) removePeer(peerKey string) error { return nil } +// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response +func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { + if e.dnsServer == nil { + 
return nil + } + + // Populate management URL if provided + if mgmtURL != nil { + if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache with management URL: %v", err) + } + } + + // Populate NetbirdConfig domains if provided + if netbirdConfig != nil { + serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig) + if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { + return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err) + } + } + + return nil +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -697,6 +751,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("handle the flow configuration: %w", err) } + if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil { + log.Warnf("Failed to update DNS server config: %v", err) + } + // todo update signal } @@ -867,7 +925,6 @@ func (e *Engine) receiveManagementEvents() { e.config.EnableSSHRemotePortForwarding, ) - // err = e.mgmClient.Sync(info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) if err != nil { // happens if management is unavailable for a long time. 
@@ -878,7 +935,7 @@ func (e *Engine) receiveManagementEvents() { } log.Debugf("stopped receiving updates from Management Service") }() - log.Debugf("connecting to Management Service updates stream") + log.Infof("connecting to Management Service updates stream") } func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { @@ -1253,7 +1310,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), }, @@ -1515,7 +1572,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) + + dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ + WgInterface: e.wgInterface, + CustomAddress: e.config.CustomDNSAddress, + StatusRecorder: e.statusRecorder, + StateManager: e.stateManager, + DisableSys: e.config.DisableDNS, + }) if err != nil { return nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index ce805c776..e82a96abf 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -19,21 +19,18 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" - wgdevice "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/management-integrations/integrations" - 
"github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" @@ -45,9 +42,12 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -85,7 +85,7 @@ type MockWGIface struct { NameFunc func() string AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface - UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error @@ -135,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface { return m.ToInterfaceFunc() } -func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { return m.UpFunc() } @@ -244,7 +244,7 @@ func 
TestEngine_SSH(t *testing.T) { UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } - err = engine.Start() + err = engine.Start(nil, nil) require.NoError(t, err) defer func() { @@ -440,7 +440,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) + engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) engine.ctx = ctx engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) @@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) { } }() - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) return @@ -1095,7 +1095,7 @@ func TestEngine_MultiplePeers(t *testing.T) { defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) wg.Done() @@ -1581,7 +1581,11 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + permissionsManager := permissions.NewManager(store) + peersManager := peers.NewManager(store, permissionsManager) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1598,7 +1602,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri Return(&types.ExtraSettings{}, nil). 
AnyTimes() - permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index bf96153ea..690fdb7cc 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -9,9 +9,9 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -24,7 +24,7 @@ type wgIfaceBase interface { Name() string Address() wgaddr.Address ToInterface() *net.Interface - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/login.go b/client/internal/login.go index ffabacf4a..28d45e49c 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, return false, err } - _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) + _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) if isLoginNeeded(err) { return true, nil } @@ -69,14 +69,18 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string, return err } - serverKey, err := doMgmLogin(ctx, 
mgmClient, pubSSHKey, config) + serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) if serverKey != nil && isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) + if err != nil { + return err + } + } else if err != nil { return err } - return err + return nil } func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { @@ -101,11 +105,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm return mgmClient, err } -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { serverKey, err := mgmClient.GetServerPublicKey() if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err + return nil, nil, err } sysInfo := system.GetInfo(ctx) @@ -125,8 +129,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.EnableSSHLocalPortForwarding, config.EnableSSHRemotePortForwarding, ) - _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, err + loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) + return serverKey, loginResp, err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. 
diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index dbb4747a5..a4ffa3a25 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -14,7 +14,7 @@ import ( "github.com/ti-mo/netfilter" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const defaultChannelSize = 100 diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index a6cf3cd25..86e4596d4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -6,12 +6,11 @@ import ( "math/rand" "net" "net/netip" - "os" "runtime" "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) - if os.Getenv("NB_FORCE_RELAY") != "true" { + if !isForceRelayed() { conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go new file mode 100644 index 000000000..32a458d00 --- /dev/null +++ b/client/internal/peer/env.go @@ -0,0 +1,14 @@ +package peer + +import ( + "os" + "strings" +) + +const ( + EnvKeyNBForceRelay = "NB_FORCE_RELAY" +) + +func isForceRelayed() bool { + return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index b9c9aa134..70850e6eb 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log 
"github.com/sirupsen/logrus" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 3cbf74cfd..42eaea683 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -43,13 +43,6 @@ type OfferAnswer struct { SessionID *ICESessionID } -func (oa *OfferAnswer) SessionIDString() string { - if oa.SessionID == nil { - return "unknown" - } - return oa.SessionID.String() -} - type Handshaker struct { mu sync.Mutex log *log.Entry @@ -57,7 +50,7 @@ type Handshaker struct { signaler *Signaler ice *WorkerICE relay *WorkerRelay - onNewOfferListeners []func(*OfferAnswer) + onNewOfferListeners []*OfferListener // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer @@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - h.onNewOfferListeners = append(h.onNewOfferListeners, offer) + l := NewOfferListener(offer) + h.onNewOfferListeners = append(h.onNewOfferListeners, l) } func (h *Handshaker) Listen(ctx context.Context) { @@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } for _, listener := range h.onNewOfferListeners { - listener(&remoteOfferAnswer) + listener.Notify(&remoteOfferAnswer) } h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) for _, listener := range h.onNewOfferListeners { - listener(&remoteOfferAnswer) + 
listener.Notify(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go new file mode 100644 index 000000000..e2d3f3f38 --- /dev/null +++ b/client/internal/peer/handshaker_listener.go @@ -0,0 +1,62 @@ +package peer + +import ( + "sync" +) + +type callbackFunc func(remoteOfferAnswer *OfferAnswer) + +func (oa *OfferAnswer) SessionIDString() string { + if oa.SessionID == nil { + return "unknown" + } + return oa.SessionID.String() +} + +type OfferListener struct { + fn callbackFunc + running bool + latest *OfferAnswer + mu sync.Mutex +} + +func NewOfferListener(fn callbackFunc) *OfferListener { + return &OfferListener{ + fn: fn, + } +} + +func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { + o.mu.Lock() + defer o.mu.Unlock() + + // If already running, just record the newest offer; the running + // goroutine picks it up after it finishes the current callback. + if o.running { + o.latest = remoteOfferAnswer + return + } + + // Not running: do not pre-store the offer, or it would be processed twice. + // Start processing + o.running = true + + // Process in a goroutine to avoid blocking the caller + go func(remoteOfferAnswer *OfferAnswer) { + for { + o.fn(remoteOfferAnswer) + + o.mu.Lock() + if o.latest == nil { + // No more work to do + o.running = false + o.mu.Unlock() + return + } + remoteOfferAnswer = o.latest + // Clear the latest to mark it as being processed + o.latest = nil + o.mu.Unlock() + } + }(remoteOfferAnswer) +} diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go new file mode 100644 index 000000000..8363741a5 --- /dev/null +++ b/client/internal/peer/handshaker_listener_test.go @@ -0,0 +1,39 @@ +package peer + +import ( + "testing" + "time" +) + +func Test_newOfferListener(t *testing.T) { + dummyOfferAnswer := &OfferAnswer{} + runChan := make(chan struct{}, 10) + + longRunningFn := func(remoteOfferAnswer *OfferAnswer) { + time.Sleep(1 * 
time.Second) + runChan <- struct{}{} + } + + hl := NewOfferListener(longRunningFn) + + hl.Notify(dummyOfferAnswer) + hl.Notify(dummyOfferAnswer) + hl.Notify(dummyOfferAnswer) + + // Wait for exactly 2 callbacks + for i := 0; i < 2; i++ { + select { + case <-runChan: + case <-time.After(3 * time.Second): + t.Fatal("Timeout waiting for callback") + } + } + + // Verify no additional callbacks happen + select { + case <-runChan: + t.Fatal("Unexpected additional callback") + case <-time.After(100 * time.Millisecond): + t.Log("Correctly received exactly 2 callbacks") + } +} diff --git a/client/internal/peer/ice/StunTurn.go b/client/internal/peer/ice/StunTurn.go index 63ee8c713..a389f5444 100644 --- a/client/internal/peer/ice/StunTurn.go +++ b/client/internal/peer/ice/StunTurn.go @@ -3,7 +3,7 @@ package ice import ( "sync/atomic" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" ) type StunTurn atomic.Value diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 58c1bf634..e80c98884 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -4,7 +4,7 @@ import ( "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" "github.com/pion/logging" "github.com/pion/randutil" log "github.com/sirupsen/logrus" diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go index dd854a605..dd5d67403 100644 --- a/client/internal/peer/ice/config.go +++ b/client/internal/peer/ice/config.go @@ -1,7 +1,7 @@ package ice import ( - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" ) type Config struct { diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index ca1d421a5..b28906625 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -1,7 +1,7 @@ package peer import ( - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" diff --git 
a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 218872c15..0ed200fda 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -30,9 +30,10 @@ type WGWatcher struct { peerKey string stateDump *stateDump - ctx context.Context - ctxCancel context.CancelFunc - ctxLock sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex + enabledTime time.Time } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() + w.enabledTime = time.Now() if w.ctx != nil && w.ctx.Err() == nil { w.log.Errorf("WireGuard watcher already enabled") @@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex onDisconnectedFn() return } + if lastHandshake.IsZero() { + elapsed := handshake.Sub(w.enabledTime).Seconds() + w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) + } + lastHandshake = *handshake resetTime := time.Until(handshake.Add(checkPeriod)) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4f00af829..eb886a4d3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -8,12 +8,11 @@ import ( "sync" "time" - "github.com/pion/ice/v3" - "github.com/pion/stun/v2" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" @@ 
-55,10 +54,6 @@ type WorkerICE struct { sessionID ICESessionID muxAgent sync.Mutex - StunTurn []*stun.URI - - sentExtraSrflx bool - localUfrag string localPwd string @@ -122,7 +117,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Warnf("failed to close ICE agent: %s", err) } w.agent = nil - // todo consider to switch to Relay connection while establishing a new ICE connection } var preferredCandidateTypes []ice.CandidateType @@ -140,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.muxAgent.Unlock() return } - w.sentExtraSrflx = false w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true @@ -167,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA w.log.Errorf("error while handling remote candidate") return } + + if shouldAddExtraCandidate(candidate) { + // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) + // this is useful when network has an existing port forwarding rule for the wireguard port and this peer + extraSrflx, err := extraSrflxCandidate(candidate) + if err != nil { + w.log.Errorf("failed creating extra server reflexive candidate %s", err) + return + } + + if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil { + w.log.Errorf("error while handling remote candidate") + return + } + } } func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { @@ -210,14 +218,12 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates [] return nil, err } - if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil { + if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) { + w.onICESelectedCandidatePair(agent, c1, c2) + }); err != nil { return nil, err } - if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil { - return nil, 
fmt.Errorf("failed setting binding response callback: %w", err) - } - return agent, nil } @@ -253,6 +259,11 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.closeAgent(agent, w.agentDialerCancel) return } + if pair == nil { + w.log.Warnf("selected candidate pair is nil, cannot proceed") + w.closeAgent(agent, w.agentDialerCancel) + return + } if !isRelayCandidate(pair.Local) { // dynamically set remote WireGuard port if other side specified a different one from the default one @@ -327,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) return } - mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) + mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault) if !ok { w.log.Warn("invalid udp mux conversion") return @@ -354,31 +365,23 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() - - if !w.shouldSendExtraSrflxCandidate(candidate) { - return - } - - // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) - // this is useful when network has an existing port forwarding rule for the wireguard port and this peer - extraSrflx, err := extraSrflxCandidate(candidate) - if err != nil { - w.log.Errorf("failed creating extra server reflexive candidate %s", err) - return - } - w.sentExtraSrflx = true - - go func() { - err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key) - if err != nil { - w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err) - } - }() } -func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { +func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.config.Key) + + 
pairStat, ok := agent.GetSelectedCandidatePairStats() + if !ok { + w.log.Warnf("failed to get selected candidate pair stats") + return + } + + duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second)) + if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { + w.log.Debugf("failed to update latency for peer: %s", err) + return + } } func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { @@ -388,7 +391,10 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia case ice.ConnectionStateConnected: w.lastKnownState = ice.ConnectionStateConnected return - case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: + case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed: + // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to + // notify the conn.onICEStateDisconnected changes to update the current used priority + if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() @@ -400,32 +406,34 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia } } -func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) { - if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil { - w.log.Debugf("failed to update latency for peer: %s", err) - return - } -} - -func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { - if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { - return true - } - return false -} - func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { - isControlling := 
w.config.LocalKey > w.config.Key - if isControlling { - return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + if isController(w.config) { + return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } } +func shouldAddExtraCandidate(candidate ice.Candidate) bool { + if candidate.Type() != ice.CandidateTypeServerReflexive { + return false + } + + if candidate.Port() == candidate.RelatedAddress().Port { + return false + } + + // in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates + // in newer version we generate locally the extra candidate + if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok { + return false + } + return true +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() - return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ Network: candidate.NetworkType().String(), Address: candidate.Address(), Port: relatedAdd.Port, @@ -433,6 +441,21 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive RelAddr: relatedAdd.Address, RelPort: relatedAdd.Port, }) + if err != nil { + return nil, err + } + + for _, e := range candidate.Extensions() { + // overwrite the original candidate ID with the new one to avoid candidate duplication + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = candidate.ID() + } + if err := ec.AddExtension(e); err != nil { + return nil, err + } + } + + return ec, nil } func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 6e1f83a9a..fa208716f 100644
--- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -7,12 +7,12 @@ import ( "sync" "time" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/turn/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ProbeResult holds the info about the result of a relay probe request diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index ba27df654..9069cdcc5 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -2,11 +2,13 @@ package dnsinterceptor import ( "context" + "errors" "fmt" "net/netip" "runtime" "strings" "sync" + "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -26,6 +28,8 @@ import ( "github.com/netbirdio/netbird/route" ) +const dnsTimeout = 8 * time.Second + type domainMap map[domain.Domain][]netip.Prefix type internalDNATer interface { @@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) + client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) if err != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return @@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + startTime := time.Now() + reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) if err != nil { - logger.Errorf("failed to 
exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + if errors.Is(err, context.DeadlineExceeded) { + elapsed := time.Since(startTime) + peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) + logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", + elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err) + } else { + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + } if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { logger.Errorf("failed writing DNS response: %v", err) } @@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR } return } + +func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { + if d.statusRecorder == nil { + return "" + } + + peerState, err := d.statusRecorder.GetPeer(peerKey) + if err != nil { + return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err) + } + + return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState)) +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index a6775c45a..04513bbe4 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,9 +36,9 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" - nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager { 
notifier := notifier.NewNotifier() sysOps := systemops.NewSysOps(config.WGInterface, notifier) + if runtime.GOOS == "windows" && config.WGInterface != nil { + nbnet.SetVPNInterfaceName(config.WGInterface.Name()) + } + dm := &DefaultManager{ ctx: mCTX, stop: cancel, @@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error { return nil } - if err := m.sysOps.CleanupRouting(nil); err != nil { + if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error { ips := resolveURLsToIPs(initialAddresses) - if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { + if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil { return fmt.Errorf("setup routing: %w", err) } @@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { - if err := m.sysOps.CleanupRouting(stateManager); err != nil { + if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") } + + if runtime.GOOS == "windows" { + nbnet.SetVPNInterfaceName("") + } } m.mux.Lock() diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index a375ce832..7cb8dae93 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -12,11 +12,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) 
CleanupRouting(*statemanager.Manager, bool) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 128afa2a5..26a548634 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -3,7 +3,6 @@ package systemops import ( - "context" "errors" "fmt" "net" @@ -22,7 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net/hooks" ) const localSubnetsCacheTTL = 15 * time.Minute @@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { return nil } - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() + hooks.RemoveWriteHooks() + hooks.RemoveCloseHooks() + hooks.RemoveAddressRemoveHooks() if err := r.refCounter.Flush(); err != nil { return fmt.Errorf("flush route manager: %w", err) @@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := util.GetPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - + beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } @@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } - afterHook := func(connID nbnet.ConnectionID) error { + afterHook := func(connID 
hooks.ConnectionID) error { if err := r.refCounter.DecrementWithID(string(connID)); err != nil { return fmt.Errorf("remove route reference: %w", err) } @@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M var merr *multierror.Error for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err)) + continue + } + if err := beforeHook("init", prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err)) } } - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } + hooks.AddWriteHook(beforeHook) + hooks.AddCloseHook(afterHook) - var merr *multierror.Error - for _, ip := range resolvedIPs { - merr = multierror.Append(merr, beforeHook(connID, ip.IP)) - } - return nberrors.FormatErrorOrNil(merr) - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { + hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.Decrement(prefix); err != nil { return fmt.Errorf("remove route reference: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 
c1c1182bc..32ea38a7a 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" ) type dialer interface { @@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) intf, err := net.InterfaceByName(wgInterface.Name()) @@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go 
b/client/internal/routemanager/systemops/systemops_ios.go index 10356eae0..99a363371 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -12,14 +12,14 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index c0cef94ba..bd10f131f 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // IPRule contains IP rule information for debugging @@ -94,15 +94,15 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. 
-func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { - if !nbnet.AdvancedRouting() { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) { + if !advancedRouting { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) } defer func() { if err != nil { - if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. 
-func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { - if !nbnet.AdvancedRouting() { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if !advancedRouting { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index f165f7779..d43c2d5bf 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -20,11 +20,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index ad37f611f..959c697e4 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type PacketExpectation struct { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 4f836897b..7bce6af80 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ 
-8,6 +8,7 @@ import ( "net/netip" "os" "runtime/debug" + "sort" "strconv" "sync" "syscall" @@ -19,9 +20,16 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" ) -const InfiniteLifetime = 0xffffffff +func init() { + nbnet.GetBestInterfaceFunc = GetBestInterface +} + +const ( + InfiniteLifetime = 0xffffffff +) type RouteUpdateType int @@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct { Table [1]MIB_IPFORWARD_ROW2 // Flexible array member } +// candidateRoute represents a potential route for selection during route lookup +type candidateRoute struct { + interfaceIndex uint32 + prefixLength uint8 + routeMetric uint32 + interfaceMetric int +} + // IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix type IP_ADDRESS_PREFIX struct { Prefix SOCKADDR_INET @@ -177,11 +193,20 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + return r.cleanupRefCounter(stateManager) } @@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) { func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) { if table != nil { - ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) - if ret != 0 { - log.Warnf("FreeMibTable failed with return code: %d", ret) - } + _, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) } } @@ 
-652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute { entryPtr := basePtr + uintptr(i)*entrySize entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) - detailed := buildWindowsDetailedRoute(entry) - if detailed != nil { + if detailed := buildWindowsDetailedRoute(entry); detailed != nil { detailedRoutes = append(detailedRoutes, *detailed) } } @@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { return ip } +// parseCandidatesFromTable extracts all matching candidate routes from the routing table +func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute { + var candidates []candidateRoute + entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{}) + basePtr := uintptr(unsafe.Pointer(&table.Table[0])) + + for i := uint32(0); i < table.NumEntries; i++ { + entryPtr := basePtr + uintptr(i)*entrySize + entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) + + if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil { + candidates = append(candidates, *candidate) + } + } + + return candidates +} + +// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry +// Returns nil if the route doesn't match the destination or should be skipped +func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute { + if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex { + return nil + } + + destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex)) + if !destPrefix.IsValid() || !destPrefix.Contains(dest) { + return nil + } + + interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family) + + return &candidateRoute{ + interfaceIndex: entry.InterfaceIndex, + prefixLength: entry.DestinationPrefix.PrefixLength, + routeMetric: entry.Metric, + interfaceMetric: interfaceMetric, 
+ } +} + // getInterfaceMetric retrieves the interface metric for a given interface and address family func getInterfaceMetric(interfaceIndex uint32, family int16) int { if interfaceIndex == 0 { @@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int { return int(ipInterfaceRow.Metric) } +// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric +func sortRouteCandidates(candidates []candidateRoute) { + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].prefixLength != candidates[j].prefixLength { + return candidates[i].prefixLength > candidates[j].prefixLength + } + if candidates[i].routeMetric != candidates[j].routeMetric { + return candidates[i].routeMetric < candidates[j].routeMetric + } + return candidates[i].interfaceMetric < candidates[j].interfaceMetric + }) +} + +// GetBestInterface finds the best interface for reaching a destination, +// excluding the VPN interface to avoid routing loops. +// +// Route selection priority: +// 1. Longest prefix match (most specific route) +// 2. Lowest route metric +// 3. 
Lowest interface metric +func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { + var skipInterfaceIndex int + if vpnIntf != "" { + if iface, err := net.InterfaceByName(vpnIntf); err == nil { + skipInterfaceIndex = iface.Index + } else { + // not critical, if we cannot get ahold of the interface then we won't need to skip it + log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err) + } + } + + table, err := getWindowsRoutingTable() + if err != nil { + return nil, fmt.Errorf("get routing table: %w", err) + } + defer freeWindowsRoutingTable(table) + + candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex) + + if len(candidates) == 0 { + return nil, fmt.Errorf("no route to %s", dest) + } + + // Sort routes: prefix length -> route metric -> interface metric + sortRouteCandidates(candidates) + + for _, candidate := range candidates { + iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex)) + if err != nil { + log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err) + continue + } + + if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() { + continue + } + + if iface.Flags&net.FlagUp == 0 { + log.Debugf("interface %s is down, trying next route", iface.Name) + continue + } + + log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d", + dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric) + return iface, nil + } + + return nil, fmt.Errorf("no usable interface found for %s", dest) +} + // formatRouteAge formats the route age in seconds to a human-readable string func formatRouteAge(ageSeconds uint32) string { if ageSeconds == 0 { diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 523bd0b0d..3561adec4 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ 
b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) var ( diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go index ac5a48e37..57ea32f69 100644 --- a/client/internal/routemanager/util/ip.go +++ b/client/internal/routemanager/util/ip.go @@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) { if !ok { return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip) } + addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) + prefix := netip.PrefixFrom(addr, addr.BitLen()) return prefix, nil } diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go index e80adb42b..8961eaa69 100644 --- a/client/internal/stdnet/dialer.go +++ b/client/internal/stdnet/dialer.go @@ -5,7 +5,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // Dial connects to the address on the named network. diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go index 9ce0a5556..d3be1896f 100644 --- a/client/internal/stdnet/listener.go +++ b/client/internal/stdnet/listener.go @@ -6,7 +6,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ListenPacket listens for incoming packets on the given network and address. 
diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 171cc42cb..4b031c05c 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri if netstack.IsEnabled() { n.iFaceDiscover = pionDiscover{} } else { - newMobileIFaceDiscover(iFaceDiscover) + n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover) } return n, n.UpdateInterfaces() } diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go new file mode 100644 index 000000000..78d70c15b --- /dev/null +++ b/client/internal/wg_iface_monitor.go @@ -0,0 +1,98 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net" + "runtime" + "time" + + log "github.com/sirupsen/logrus" +) + +// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine +// if the interface is deleted externally while the engine is running. +type WGIfaceMonitor struct { + done chan struct{} +} + +// NewWGIfaceMonitor creates a new WGIfaceMonitor instance. +func NewWGIfaceMonitor() *WGIfaceMonitor { + return &WGIfaceMonitor{ + done: make(chan struct{}), + } +} + +// Start begins monitoring the WireGuard interface. +// It relies on the provided context cancellation to stop. 
+func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { + defer close(m.done) + + // Skip on mobile platforms as they handle interface lifecycle differently + if runtime.GOOS == "android" || runtime.GOOS == "ios" { + log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS) + return false, errors.New("not supported on mobile platforms") + } + + if ifaceName == "" { + log.Debugf("Interface monitor: empty interface name, skipping monitor") + return false, errors.New("empty interface name") + } + + // Get initial interface index to track the specific interface instance + expectedIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName) + return false, fmt.Errorf("interface %s not found: %w", ifaceName, err) + } + + log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } + +} + +// getInterfaceIndex returns the index of a network interface by name. +// Returns an error if the interface is not found. 
+func getInterfaceIndex(name string) (int, error) { + if name == "" { + return 0, fmt.Errorf("empty interface name") + } + ifi, err := net.InterfaceByName(name) + if err != nil { + // Check if it's specifically a "not found" error + if errors.Is(err, &net.OpError{}) { + // On some systems, this might be a "not found" error + return 0, fmt.Errorf("interface not found: %w", err) + } + return 0, fmt.Errorf("failed to lookup interface: %w", err) + } + if ifi == nil { + return 0, fmt.Errorf("interface not found") + } + return ifi.Index, nil +} diff --git a/client/net/conn.go b/client/net/conn.go new file mode 100644 index 000000000..918e7f628 --- /dev/null +++ b/client/net/conn.go @@ -0,0 +1,49 @@ +//go:build !ios + +package net + +import ( + "io" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/net/hooks" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID hooks.ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +func (c *Conn) Close() error { + return closeConn(c.ID, c.Conn) +} + +// TCPConn wraps net.TCPConn to override its Close method to include hook functionality. +type TCPConn struct { + *net.TCPConn + ID hooks.ConnectionID +} + +// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +func (c *TCPConn) Close() error { + return closeConn(c.ID, c.TCPConn) +} + +// closeConn is a helper function to close connections and execute close hooks. 
+func closeConn(id hooks.ConnectionID, conn io.Closer) error { + err := conn.Close() + + closeHooks := hooks.GetCloseHooks() + for _, hook := range closeHooks { + if err := hook(id); err != nil { + log.Errorf("Error executing close hook: %v", err) + } + } + + return err +} diff --git a/client/net/dial.go b/client/net/dial.go new file mode 100644 index 000000000..041a00e5d --- /dev/null +++ b/client/net/dial.go @@ -0,0 +1,82 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + udpConn, ok := c.Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.TCPConn: + 
// Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + tcpConn, ok := c.Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn) + } + return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/dial_ios.go b/client/net/dial_ios.go similarity index 100% rename from util/net/dial_ios.go rename to client/net/dial_ios.go diff --git a/util/net/dialer.go b/client/net/dialer.go similarity index 99% rename from util/net/dialer.go rename to client/net/dialer.go index 0786c667e..29bec05a7 100644 --- a/util/net/dialer.go +++ b/client/net/dialer.go @@ -16,6 +16,5 @@ func NewDialer() *Dialer { Dialer: &net.Dialer{}, } dialer.init() - return dialer } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go new file mode 100644 index 000000000..2e1eb53d8 --- /dev/null +++ b/client/net/dialer_dial.go @@ -0,0 +1,87 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + log.Debugf("Dialing %s %s", network, address) + + if CustomRoutingDisabled() || AdvancedRouting() { + return d.Dialer.DialContext(ctx, network, address) + } + + connID := hooks.GenerateConnID() + if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil 
{ + log.Errorf("Failed to call dialer hooks: %v", err) + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error { + if ctx.Err() != nil { + return ctx.Err() + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + + resolver := customResolver + if resolver == nil { + resolver = net.DefaultResolver + } + + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var merr *multierror.Error + for _, ip := range ips { + prefix, err := util.GetPrefixFromIP(ip.IP) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err)) + continue + } + for _, hook := range writeHooks { + if err := hook(connID, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err)) + } + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/dialer_init_android.go b/client/net/dialer_init_android.go similarity index 100% rename from util/net/dialer_init_android.go rename to client/net/dialer_init_android.go diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go new file mode 100644 index 000000000..18ebc6ad1 --- /dev/null +++ 
b/client/net/dialer_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (d *Dialer) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/dialer_init_linux.go b/client/net/dialer_init_linux.go similarity index 100% rename from util/net/dialer_init_linux.go rename to client/net/dialer_init_linux.go diff --git a/client/net/dialer_init_windows.go b/client/net/dialer_init_windows.go new file mode 100644 index 000000000..6eefe5b1e --- /dev/null +++ b/client/net/dialer_init_windows.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyUnicastIFToSocket +} diff --git a/util/net/env.go b/client/net/env.go similarity index 94% rename from util/net/env.go rename to client/net/env.go index 32425665d..8f326ca88 100644 --- a/util/net/env.go +++ b/client/net/env.go @@ -11,6 +11,7 @@ import ( const ( envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" ) // CustomRoutingDisabled returns true if custom routing is disabled. diff --git a/client/net/env_android.go b/client/net/env_android.go new file mode 100644 index 000000000..9d89951a1 --- /dev/null +++ b/client/net/env_android.go @@ -0,0 +1,24 @@ +//go:build android + +package net + +// Init initializes the network environment for Android +func Init() { + // No initialization needed on Android +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on Android since we cannot handle routes dynamically. 
+func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on Android +func SetVPNInterfaceName(name string) { + // No-op on Android - not needed for Android VPN service +} + +// GetVPNInterfaceName returns empty string on Android +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_generic.go b/client/net/env_generic.go new file mode 100644 index 000000000..f467930c3 --- /dev/null +++ b/client/net/env_generic.go @@ -0,0 +1,23 @@ +//go:build !linux && !windows && !android + +package net + +// Init initializes the network environment (no-op on non-Linux/Windows platforms) +func Init() { + // No-op on non-Linux/Windows platforms +} + +// AdvancedRouting returns false on non-Linux/Windows platforms +func AdvancedRouting() bool { + return false +} + +// SetVPNInterfaceName is a no-op on non-Windows platforms +func SetVPNInterfaceName(name string) { + // No-op on non-Windows platforms +} + +// GetVPNInterfaceName returns empty string on non-Windows platforms +func GetVPNInterfaceName() string { + return "" +} diff --git a/util/net/env_linux.go b/client/net/env_linux.go similarity index 86% rename from util/net/env_linux.go rename to client/net/env_linux.go index 3159f6462..82d9a74a8 100644 --- a/util/net/env_linux.go +++ b/client/net/env_linux.go @@ -17,8 +17,7 @@ import ( const ( // these have the same effect, skip socket env supported for backward compatibility - envSkipSocketMark = "NB_SKIP_SOCKET_MARK" - envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" ) var advancedRoutingSupported bool @@ -27,6 +26,7 @@ func Init() { advancedRoutingSupported = checkAdvancedRoutingSupport() } +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes func AdvancedRouting() bool { return advancedRoutingSupported } @@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool { } func CheckFwmarkSupport() bool { - // temporarily enable advanced routing to 
check fwmarks are supported + // temporarily enable advanced routing to check if fwmarks are supported old := advancedRoutingSupported advancedRoutingSupported = true defer func() { @@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool { } return true } + +// SetVPNInterfaceName is a no-op on Linux +func SetVPNInterfaceName(name string) { + // No-op on Linux - not needed for fwmark-based routing +} + +// GetVPNInterfaceName returns empty string on Linux +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_windows.go b/client/net/env_windows.go new file mode 100644 index 000000000..7e8868ba5 --- /dev/null +++ b/client/net/env_windows.go @@ -0,0 +1,67 @@ +//go:build windows + +package net + +import ( + "os" + "strconv" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +var ( + vpnInterfaceName string + vpnInitMutex sync.RWMutex + + advancedRoutingSupported bool +) + +func Init() { + advancedRoutingSupported = checkAdvancedRoutingSupport() +} + +func checkAdvancedRoutingSupport() bool { + var err error + var legacyRouting bool + if val := os.Getenv(envUseLegacyRouting); val != "" { + legacyRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + } + } + + if legacyRouting || netstack.IsEnabled() { + log.Info("advanced routing has been requested to be disabled") + return false + } + + log.Info("system supports advanced routing") + + return true +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes +func AdvancedRouting() bool { + return advancedRoutingSupported +} + +// GetVPNInterfaceName returns the stored VPN interface name +func GetVPNInterfaceName() string { + vpnInitMutex.RLock() + defer vpnInitMutex.RUnlock() + return vpnInterfaceName +} + +// SetVPNInterfaceName sets the VPN interface name for lazy initialization +func SetVPNInterfaceName(name string) { + 
vpnInitMutex.Lock() + defer vpnInitMutex.Unlock() + vpnInterfaceName = name + + if name != "" { + log.Infof("VPN interface name set to %s for route exclusion", name) + } +} diff --git a/client/net/hooks/hooks.go b/client/net/hooks/hooks.go new file mode 100644 index 000000000..93d8e18ef --- /dev/null +++ b/client/net/hooks/hooks.go @@ -0,0 +1,93 @@ +package hooks + +import ( + "net/netip" + "slices" + "sync" + + "github.com/google/uuid" +) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} + +type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error +type CloseHookFunc func(connID ConnectionID) error +type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + +var ( + hooksMutex sync.RWMutex + + writeHooks []WriteHookFunc + closeHooks []CloseHookFunc + addressRemoveHooks []AddressRemoveHookFunc +) + +// AddWriteHook allows adding a new hook to be executed before writing/dialing. +func AddWriteHook(hook WriteHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = append(writeHooks, hook) +} + +// AddCloseHook allows adding a new hook to be executed on connection close. +func AddCloseHook(hook CloseHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = append(closeHooks, hook) +} + +// RemoveWriteHooks removes all write hooks. +func RemoveWriteHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = nil +} + +// RemoveCloseHooks removes all close hooks. +func RemoveCloseHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = nil +} + +// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed. 
+func AddAddressRemoveHook(hook AddressRemoveHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = append(addressRemoveHooks, hook) +} + +// RemoveAddressRemoveHooks removes all listener address hooks. +func RemoveAddressRemoveHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = nil +} + +// GetWriteHooks returns a copy of the current write hooks. +func GetWriteHooks() []WriteHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(writeHooks) +} + +// GetCloseHooks returns a copy of the current close hooks. +func GetCloseHooks() []CloseHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(closeHooks) +} + +// GetAddressRemoveHooks returns a copy of the current listener address remove hooks. +func GetAddressRemoveHooks() []AddressRemoveHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(addressRemoveHooks) +} diff --git a/client/net/listen.go b/client/net/listen.go new file mode 100644 index 000000000..da7262806 --- /dev/null +++ b/client/net/listen.go @@ -0,0 +1,47 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. 
+func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *PacketConn: + // Legacy routing: wrapped connection for hooks + udpConn, ok := c.PacketConn.(*net.UDPConn) + if !ok { + if err := c.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/listen_ios.go b/client/net/listen_ios.go similarity index 100% rename from util/net/listen_ios.go rename to client/net/listen_ios.go diff --git a/util/net/listener.go b/client/net/listener.go similarity index 81% rename from util/net/listener.go rename to client/net/listener.go index f4d769f58..4c2f53c05 100644 --- a/util/net/listener.go +++ b/client/net/listener.go @@ -7,14 +7,12 @@ import ( // ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before // responding via the socket and after closing. This can be used to bypass the VPN for listeners. type ListenerConfig struct { - *net.ListenConfig + net.ListenConfig } // NewListener creates a new ListenerConfig instance. 
func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } + listener := &ListenerConfig{} listener.init() return listener diff --git a/util/net/listener_init_android.go b/client/net/listener_init_android.go similarity index 100% rename from util/net/listener_init_android.go rename to client/net/listener_init_android.go diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go new file mode 100644 index 000000000..4f8f17ab2 --- /dev/null +++ b/client/net/listener_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (l *ListenerConfig) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/listener_init_linux.go b/client/net/listener_init_linux.go similarity index 100% rename from util/net/listener_init_linux.go rename to client/net/listener_init_linux.go diff --git a/client/net/listener_init_windows.go b/client/net/listener_init_windows.go new file mode 100644 index 000000000..a9399b5f1 --- /dev/null +++ b/client/net/listener_init_windows.go @@ -0,0 +1,8 @@ +package net + +func (l *ListenerConfig) init() { + // TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses. + // For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case + // the interface will be selected that serves the default route. 
+ l.ListenConfig.Control = applyUnicastIFToSocket +} diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go new file mode 100644 index 000000000..0bb5ad67d --- /dev/null +++ b/client/net/listener_listen.go @@ -0,0 +1,153 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() || AdvancedRouting() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := hooks.GenerateConnID() + + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. 
+func (c *PacketConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.UDPConn) +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *PacketConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen()) + + addressRemoveHooks := hooks.GetAddressRemoveHooks() + if len(addressRemoveHooks) == 0 { + return + } + + for _, hook := range addressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + +// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality +func WrapPacketConn(conn net.PacketConn) net.PacketConn { + if AdvancedRouting() { + // hooks not required for advanced routing + return conn + } + return &PacketConn{ + PacketConn: conn, + ID: hooks.GenerateConnID(), + seenAddrs: 
&sync.Map{}, + } +} + +func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error { + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded { + return nil + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr) + } + + prefix, err := util.GetPrefixFromIP(udpAddr.IP) + if err != nil { + return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err) + } + + log.Debugf("Listener resolved IP for %s: %s", addr, prefix) + + var merr *multierror.Error + for _, hook := range writeHooks { + if err := hook(id, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/listener_listen_ios.go b/client/net/listener_listen_ios.go similarity index 100% rename from util/net/listener_listen_ios.go rename to client/net/listener_listen_ios.go diff --git a/util/net/net.go b/client/net/net.go similarity index 81% rename from util/net/net.go rename to client/net/net.go index fdcf4ee6a..a97de9d59 100644 --- a/util/net/net.go +++ b/client/net/net.go @@ -5,8 +5,6 @@ import ( "math/big" "net" "net/netip" - - "github.com/google/uuid" ) const ( @@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool { return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper } -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -type AddHookFunc func(connID ConnectionID, IP net.IP) error -type RemoveHookFunc func(connID ConnectionID) error - -// GenerateConnID generates a unique identifier for each connection. 
-func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} - func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) { var endIP net.IP addr := network.Addr().AsSlice() diff --git a/util/net/net_linux.go b/client/net/net_linux.go similarity index 100% rename from util/net/net_linux.go rename to client/net/net_linux.go diff --git a/util/net/net_test.go b/client/net/net_test.go similarity index 100% rename from util/net/net_test.go rename to client/net/net_test.go diff --git a/client/net/net_windows.go b/client/net/net_windows.go new file mode 100644 index 000000000..649d83aaf --- /dev/null +++ b/client/net/net_windows.go @@ -0,0 +1,284 @@ +package net + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "syscall" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options + IpUnicastIf = 31 + Ipv6UnicastIf = 31 + + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options + Ipv6V6only = 27 +) + +// GetBestInterfaceFunc is set at runtime to avoid import cycle +var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error) + +// nativeToBigEndian converts a uint32 from native byte order to big-endian +func nativeToBigEndian(v uint32) uint32 { + return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24 +} + +// parseDestinationAddress parses the destination address from various formats +func parseDestinationAddress(network, address string) (netip.Addr, error) { + if address == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + if addrPort, err := netip.ParseAddrPort(address); err == nil { + return addrPort.Addr(), nil + } + + if dest, err := netip.ParseAddr(address); err == nil { + return dest, nil + } + + host, 
_, err := net.SplitHostPort(address) + if err != nil { + // No port, treat whole string as host + host = address + } + + if host == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil || len(ips) == 0 { + return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err) + } + + dest, ok := netip.AddrFromSlice(ips[0].IP) + if !ok { + return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP) + } + + if ips[0].Zone != "" { + dest = dest.WithZone(ips[0].Zone) + } + + return dest, nil +} + +func getInterfaceFromZone(zone string) *net.Interface { + if zone == "" { + return nil + } + + idx, err := strconv.Atoi(zone) + if err != nil { + log.Debugf("invalid zone format for Windows (expected numeric): %s", zone) + return nil + } + + iface, err := net.InterfaceByIndex(idx) + if err != nil { + log.Debugf("failed to get interface by index %d from zone: %v", idx, err) + return nil + } + + return iface +} + +type interfaceSelection struct { + iface4 *net.Interface + iface6 *net.Interface +} + +func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection { + iface := getInterfaceFromZone(zone) + if iface == nil { + return nil + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface} + } + return &interfaceSelection{iface4: iface} +} + +func selectInterfaceForUnspecified() (*interfaceSelection, error) { + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + var result interfaceSelection + vpnIfaceName := GetVPNInterfaceName() + + if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil { + result.iface4 = iface4 + } else { + log.Debugf("No IPv4 default route found: %v", err) + } + + if iface6, err := 
GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil { + result.iface6 = iface6 + } else { + log.Debugf("No IPv6 default route found: %v", err) + } + + if result.iface4 == nil && result.iface6 == nil { + return nil, errors.New("no default routes found") + } + + return &result, nil +} + +func selectInterface(dest netip.Addr) (*interfaceSelection, error) { + if zone := dest.Zone(); zone != "" { + if selection := selectInterfaceForZone(dest, zone); selection != nil { + return selection, nil + } + } + + if dest.IsUnspecified() { + return selectInterfaceForUnspecified() + } + + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName()) + if err != nil { + return nil, fmt.Errorf("find route for %s: %w", dest, err) + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface}, nil + } + return &interfaceSelection{iface4: iface}, nil +} + +func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error { + ifaceIndexBE := nativeToBigEndian(uint32(iface.Index)) + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil { + return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error { + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil { + return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error { + // The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.) 
+ // Never generic ones (udp, tcp, ip) + + switch { + case strings.HasSuffix(network, "4"): + // IPv4-only socket (udp4, tcp4, ip4) + return setUnicastIfIPv4(fd, network, selection, address) + + case strings.HasSuffix(network, "6"): + // IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only + return setUnicastIfIPv6(fd, network, selection, address) + } + + // Shouldn't reach here based on Go's documented behavior + return fmt.Errorf("unexpected network type: %s", network) +} + +func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error { + if selection.iface4 == nil { + return nil + } + + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address) + return nil +} + +func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error { + isDualStack := checkDualStack(fd) + + // For dual-stack sockets, also set the IPv4 option + if isDualStack && selection.iface4 != nil { + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address) + } + + if selection.iface6 == nil { + return nil + } + + if err := setIPv6UnicastIF(fd, selection.iface6); err != nil { + return err + } + + log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address) + return nil +} + +func checkDualStack(fd uintptr) bool { + var v6Only int + v6OnlyLen := int32(unsafe.Sizeof(v6Only)) + err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen) + return err == nil && v6Only == 0 +} + +// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address +func 
applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + dest, err := parseDestinationAddress(network, address) + if err != nil { + return err + } + + dest = dest.Unmap() + + if !dest.IsValid() { + return fmt.Errorf("invalid destination address for %s", address) + } + + selection, err := selectInterface(dest) + if err != nil { + return err + } + + var controlErr error + err = c.Control(func(fd uintptr) { + controlErr = setUnicastIf(fd, network, selection, address) + }) + + if err != nil { + return fmt.Errorf("control: %w", err) + } + + return controlErr +} diff --git a/util/net/protectsocket_android.go b/client/net/protectsocket_android.go similarity index 100% rename from util/net/protectsocket_android.go rename to client/net/protectsocket_android.go diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 2422d2683..7c9fa021a 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -2,7 +2,7 @@ set -eEuo pipefail : ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"} +: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() @@ -39,7 +39,7 @@ wait_for_message() { info "not waiting for log line ${message@Q} due to zero timeout." elif test -n "${log_file_path}"; then info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) + grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) else info "log file unsupported, sleeping for ${timeout} seconds..." 
sleep "${timeout}" @@ -81,7 +81,7 @@ wait_for_daemon_startup() { login_if_needed() { local timeout="${1}" - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then + if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then info "already logged in, skipping 'netbird up'..." else info "logging in..." diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 13cc8cdbc..0ea294c53 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v5.29.3 +// protoc v6.32.1 // source: daemon.proto package proto @@ -826,8 +826,10 @@ type StatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"` ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // the UI does not use this yet, but CLIs could use it to wait until the status is ready + WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *StatusRequest) Reset() { @@ -874,6 +876,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool { return false } +func (x *StatusRequest) GetWaitForReady() bool { + if x != nil && x.WaitForReady != nil { + return *x.WaitForReady + } + return false +} + type StatusResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // status of the server.
@@ -4904,10 +4913,12 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\f\n" + "\n" + - "UpResponse\"g\n" + + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + "\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" + - "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" + + "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" + + "\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" + + "\r_waitForReady\"\x82\x01\n" + "\x0eStatusResponse\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status\x122\n" + "\n" + @@ -5491,6 +5502,7 @@ func file_daemon_proto_init() { } file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[7].OneofWrappers = []any{} file_daemon_proto_msgTypes[26].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index de62493af..2d904cb32 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -194,6 +194,8 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; bool shouldRunProbes = 2; + // the UI does not use this yet, but CLIs could use it to wait until the status is ready + optional bool waitForReady = 3; } message StatusResponse{ diff --git a/client/server/server.go b/client/server/server.go index 4b0c59e4d..864b2c506 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -65,6 +65,9 @@ type Server struct { mutex sync.Mutex config *profilemanager.Config proto.UnimplementedDaemonServiceServer + clientRunning bool // protected by mutex + clientRunningChan chan struct{} + clientGiveUpChan chan struct{} connectClient *internal.ConnectClient @@ -103,6 +106,11 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() + + if s.clientRunning { +
return nil + } + state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -172,8 +180,10 @@ func (s *Server) Start() error { return nil } - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) - + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } @@ -204,12 +214,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - runningChan chan struct{}, -) { - backOff := getConnectWithBackoff(ctx) - retryStarted := false +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { + defer func() { + s.mutex.Lock() + s.clientRunning = false + s.mutex.Unlock() + }() + if s.config.DisableAutoConnect { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + log.Debugf("run client connection exited with error: %v", err) + } + log.Tracef("client connection exited") + return + } + + backOff := getConnectWithBackoff(ctx) go func() { t := time.NewTicker(24 * time.Hour) for { @@ -218,89 +238,36 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage t.Stop() return case <-t.C: - if retryStarted { - - mgmtState := statusRecorder.GetManagementState() - signalState := statusRecorder.GetSignalState() - if mgmtState.Connected && signalState.Connected { - 
log.Tracef("resetting status") - retryStarted = false - } else { - log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) - } + mgmtState := statusRecorder.GetManagementState() + signalState := statusRecorder.GetSignalState() + if mgmtState.Connected && signalState.Connected { + log.Tracef("resetting status") + backOff.Reset() + } else { + log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) } } } }() runOperation := func() error { - log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - - err := s.connectClient.Run(runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) + return err } - if config.DisableAutoConnect { - return backoff.Permanent(err) - } - - if !retryStarted { - retryStarted = true - backOff.Reset() - } - - log.Tracef("client connection exited") - return fmt.Errorf("client connection exited") + log.Tracef("client connection exited gracefully, do not need to retry") + return nil } - err := backoff.Retry(runOperation, backOff) - if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { - log.Errorf("received an error when trying to connect: %v", err) - } else { - log.Tracef("retry canceled") - } -} - -// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries -func getConnectWithBackoff(ctx context.Context) backoff.BackOff { - initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) - maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) - maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) - multiplier := defaultRetryMultiplier - - if envValue := 
os.Getenv(retryMultiplierVar); envValue != "" { - // parse the multiplier from the environment variable string value to float64 - value, err := strconv.ParseFloat(envValue, 64) - if err != nil { - log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) - } else { - multiplier = value - } + if err := backoff.Retry(runOperation, backOff); err != nil { + log.Errorf("operation failed: %v", err) } - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: initialInterval, - RandomizationFactor: 1, - Multiplier: multiplier, - MaxInterval: maxInterval, - MaxElapsedTime: maxElapsedTime, // 14 days - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - -// parseEnvDuration parses the environment variable and returns the duration -func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { - if envValue := os.Getenv(envVar); envValue != "" { - if duration, err := time.ParseDuration(envValue); err == nil { - return duration - } - log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + if giveUpChan != nil { + close(giveUpChan) } - return defaultDuration } // loginAttempt attempts to login using the provided information. 
it returns a status in case something fails @@ -423,7 +390,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if s.actCancel != nil { s.actCancel() } - ctx, cancel := context.WithCancel(s.rootCtx) + ctx, cancel := context.WithCancel(callerCtx) md, ok := metadata.FromIncomingContext(callerCtx) if ok { @@ -433,11 +400,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() - if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { + if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } - state := internal.CtxGetState(ctx) + state := internal.CtxGetState(s.rootCtx) defer func() { status, err := state.Status() if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) { @@ -650,6 +617,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin // Up starts engine work in the daemon. func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { s.mutex.Lock() + if s.clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return nil, err + } + if status == internal.StatusNeedsLogin { + s.actCancel() + } + s.mutex.Unlock() + + return s.waitForUp(callerCtx) + } defer s.mutex.Unlock() if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { @@ -665,16 +646,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR if err != nil { return nil, err } + if status != internal.StatusIdle { return nil, fmt.Errorf("up already in progress: current status %s", status) } - // it should be nil here, but . + // it should be nil here, but in case it isn't we cancel it. 
if s.actCancel != nil { s.actCancel() } ctx, cancel := context.WithCancel(s.rootCtx) - md, ok := metadata.FromIncomingContext(callerCtx) if ok { ctx = metadata.NewOutgoingContext(ctx, md) @@ -717,23 +698,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + return s.waitForUp(callerCtx) +} + +// todo: handle potential race conditions +func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) { timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) - for { - select { - case <-runningChan: - s.isSessionActive.Store(true) - return &proto.UpResponse{}, nil - case <-callerCtx.Done(): - log.Debug("context done, stopping the wait for engine to become ready") - return nil, callerCtx.Err() - case <-timeoutCtx.Done(): - log.Debug("up is timed out, stopping the wait for engine to become ready") - return nil, timeoutCtx.Err() - } + select { + case <-s.clientGiveUpChan: + return nil, fmt.Errorf("client gave up to connect") + case <-s.clientRunningChan: + s.isSessionActive.Store(true) + return &proto.UpResponse{}, nil + case <-callerCtx.Done(): + log.Debug("context done, stopping the wait for engine to become ready") + return nil, callerCtx.Err() + case <-timeoutCtx.Done(): + log.Debug("up is timed out, stopping the wait for engine to become ready") + return nil, timeoutCtx.Err() } } @@ -1007,12 +996,46 @@ func (s *Server) Status( ctx context.Context, msg 
*proto.StatusRequest, ) (*proto.StatusResponse, error) { - if ctx.Err() != nil { - return nil, ctx.Err() - } - s.mutex.Lock() - defer s.mutex.Unlock() + clientRunning := s.clientRunning + s.mutex.Unlock() + + if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + return nil, err + } + + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + loop: + for { + select { + case <-s.clientGiveUpChan: + ticker.Stop() + break loop + case <-s.clientRunningChan: + ticker.Stop() + break loop + case <-ticker.C: + status, err := state.Status() + if err != nil { + continue + } + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + continue + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } status, err := internal.CtxGetState(s.rootCtx).Status() if err != nil { @@ -1194,6 +1217,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p }, nil } +// AddProfile adds a new profile to the daemon. +func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.checkProfilesDisabled() { + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + + if msg.ProfileName == "" || msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") + } + + if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to create profile: %v", err) + return nil, fmt.Errorf("failed to create profile: %w", err) + } + + return &proto.AddProfileResponse{}, nil +} + +// RemoveProfile removes a profile from the daemon. 
+func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { + return nil, err + } + + if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { + log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) + } + + if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to remove profile: %v", err) + return nil, fmt.Errorf("failed to remove profile: %w", err) + } + + return &proto.RemoveProfileResponse{}, nil +} + +// ListProfiles lists all profiles in the daemon. +func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") + } + + profiles, err := s.profileManager.ListProfiles(msg.Username) + if err != nil { + log.Errorf("failed to list profiles: %v", err) + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + response := &proto.ListProfilesResponse{ + Profiles: make([]*proto.Profile, len(profiles)), + } + for i, profile := range profiles { + response.Profiles[i] = &proto.Profile{ + Name: profile.Name, + IsActive: profile.IsActive, + } + } + + return response, nil +} + +// GetActiveProfile returns the active profile in the daemon. 
+func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + activeProfile, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + return &proto.GetActiveProfileResponse{ + ProfileName: activeProfile.Name, + Username: activeProfile.Username, + }, nil +} + +// GetFeatures returns the features supported by the daemon. +func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + features := &proto.GetFeaturesResponse{ + DisableProfiles: s.checkProfilesDisabled(), + DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + } + + return features, nil +} + +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { + log.Tracef("running client connection") + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) + if err := s.connectClient.Run(runningChan); err != nil { + return err + } + return nil +} + +func (s *Server) checkProfilesDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.profilesDisabled { + return true + } + + return false +} + +func (s *Server) checkUpdateSettingsDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.updateSettingsDisabled { + return true + } + + return false +} + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() @@ -1205,6 +1356,45 @@ func (s *Server) onSessionExpire() { } } +// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries 
+func getConnectWithBackoff(ctx context.Context) backoff.BackOff { + initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) + maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) + maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) + multiplier := defaultRetryMultiplier + + if envValue := os.Getenv(retryMultiplierVar); envValue != "" { + // parse the multiplier from the environment variable string value to float64 + value, err := strconv.ParseFloat(envValue, 64) + if err != nil { + log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) + } else { + multiplier = value + } + } + + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: initialInterval, + RandomizationFactor: 1, + Multiplier: multiplier, + MaxInterval: maxInterval, + MaxElapsedTime: maxElapsedTime, // 14 days + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// parseEnvDuration parses the environment variable and returns the duration +func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { + if envValue := os.Getenv(envVar); envValue != "" { + if duration, err := time.ParseDuration(envValue); err == nil { + return duration + } + log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + } + return defaultDuration +} + func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus := proto.FullStatus{ ManagementState: &proto.ManagementState{}, @@ -1320,121 +1510,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// AddProfile adds a new profile to the daemon. 
-func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.checkProfilesDisabled() { - return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) - } - - if msg.ProfileName == "" || msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") - } - - if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to create profile: %v", err) - return nil, fmt.Errorf("failed to create profile: %w", err) - } - - return &proto.AddProfileResponse{}, nil -} - -// RemoveProfile removes a profile from the daemon. -func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { - return nil, err - } - - if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { - log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) - } - - if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to remove profile: %v", err) - return nil, fmt.Errorf("failed to remove profile: %w", err) - } - - return &proto.RemoveProfileResponse{}, nil -} - -// ListProfiles lists all profiles in the daemon. 
-func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") - } - - profiles, err := s.profileManager.ListProfiles(msg.Username) - if err != nil { - log.Errorf("failed to list profiles: %v", err) - return nil, fmt.Errorf("failed to list profiles: %w", err) - } - - response := &proto.ListProfilesResponse{ - Profiles: make([]*proto.Profile, len(profiles)), - } - for i, profile := range profiles { - response.Profiles[i] = &proto.Profile{ - Name: profile.Name, - IsActive: profile.IsActive, - } - } - - return response, nil -} - -// GetActiveProfile returns the active profile in the daemon. -func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - activeProfile, err := s.profileManager.GetActiveProfileState() - if err != nil { - log.Errorf("failed to get active profile state: %v", err) - return nil, fmt.Errorf("failed to get active profile state: %w", err) - } - - return &proto.GetActiveProfileResponse{ - ProfileName: activeProfile.Name, - Username: activeProfile.Username, - }, nil -} - -// GetFeatures returns the features supported by the daemon. 
-func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - features := &proto.GetFeaturesResponse{ - DisableProfiles: s.checkProfilesDisabled(), - DisableUpdateSettings: s.checkUpdateSettingsDisabled(), - } - - return features, nil -} - -func (s *Server) checkProfilesDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.profilesDisabled { - return true - } - - return false -} - -func (s *Server) checkUpdateSettingsDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.updateSettingsDisabled { - return true - } - - return false -} diff --git a/client/server/server_test.go b/client/server/server_test.go index 24ff9fb0c..755925003 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,25 +10,25 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/groups" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" 
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) { profName := "default" + u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") + require.NoError(t, err) + ic := profilemanager.ConfigInput{ - ConfigPath: filepath.Join(tempDir, profName+".json"), + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: u.String(), } _, err = profilemanager.UpdateOrCreateConfig(ic) @@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) { } s := New(ctx, "console", "", false, false) - err = s.Start() require.NoError(t, err) - u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") - require.NoError(t, err) - s.config = &profilemanager.Config{ - ManagementURL: u, - } - upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) { Username: &currUser.Username, } _, err = s.Up(upCtx, upReq) + log.Errorf("error from Up: %v", err) assert.Contains(t, err.Error(), "context deadline exceeded") } @@ -294,15 +292,20 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + permissionsManagerMock := permissions.NewMockManager(ctrl) + 
peersManager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - permissionsManagerMock := permissions.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) diff --git a/client/system/info.go b/client/system/info.go index 90abf864b..1e4342e34 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -6,6 +6,7 @@ import ( "net/netip" "strings" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/metadata" "github.com/netbirdio/netbird/shared/management/proto" @@ -114,14 +115,6 @@ func (i *Info) SetFlags( } } -// StaticInfo is an object that contains machine information that does not change -type StaticInfo struct { - SystemSerialNumber string - SystemProductName string - SystemManufacturer string - Environment Environment -} - // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) @@ -199,6 +192,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { // GetInfoWithChecks retrieves and parses the system information with applied checks. 
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { + log.Debugf("gathering system information with checks: %d", len(checks)) processCheckPaths := make([]string, 0) for _, check := range checks { processCheckPaths = append(processCheckPaths, check.GetFiles()...) @@ -208,16 +202,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro if err != nil { return nil, err } + log.Debugf("gathering process check information completed") info := GetInfo(ctx) info.Files = files + log.Debugf("all system information gathered successfully") return info, nil } - -// UpdateStaticInfo asynchronously updates static system and platform information -func UpdateStaticInfo() { - go func() { - _ = updateStaticInfo() - }() -} diff --git a/client/system/info_android.go b/client/system/info_android.go index 56fe0741d..78895bfa8 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -15,6 +15,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { kernel := "android" diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index f105ada60..caa344737 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -19,6 +19,10 @@ import ( "github.com/netbirdio/netbird/version" ) +func UpdateStaticInfoAsync() { + go updateStaticInfo() +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { utsname := unix.Utsname{} @@ -41,7 +45,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_freebsd.go 
b/client/system/info_freebsd.go index bed6711de..8e1353151 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -18,6 +18,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { out := _getInfo() diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 897ec0a35..705c37920 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -10,6 +10,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { diff --git a/client/system/info_linux.go b/client/system/info_linux.go index 9bfc82009..6c7a23b95 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -23,6 +23,10 @@ var ( getSystemInfo = defaultSysInfoImplementation ) +func UpdateStaticInfoAsync() { + go updateStaticInfo() +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { info := _getInfo() @@ -48,7 +52,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_windows.go b/client/system/info_windows.go index 6f05ded20..d7f8f30aa 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -2,187 +2,51 @@ package system import ( "context" - "fmt" "os" "runtime" - "strings" "time" log "github.com/sirupsen/logrus" - "github.com/yusufpapurcu/wmi" - "golang.org/x/sys/windows/registry" "github.com/netbirdio/netbird/version" 
) -type Win32_OperatingSystem struct { - Caption string -} - -type Win32_ComputerSystem struct { - Manufacturer string -} - -type Win32_ComputerSystemProduct struct { - Name string -} - -type Win32_BIOS struct { - SerialNumber string +func UpdateStaticInfoAsync() { + go updateStaticInfo() } // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - osName, osVersion := getOSNameAndVersion() - buildVersion := getBuildVersion() - - addrs, err := networkAddresses() - if err != nil { - log.Warnf("failed to discover network addresses: %s", err) - } - start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ Kernel: "windows", - OSVersion: osVersion, + OSVersion: si.OSVersion, Platform: "unknown", - OS: osName, + OS: si.OSName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), - KernelVersion: buildVersion, - NetworkAddresses: addrs, + KernelVersion: si.BuildVersion, SystemSerialNumber: si.SystemSerialNumber, SystemProductName: si.SystemProductName, SystemManufacturer: si.SystemManufacturer, Environment: si.Environment, } + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } else { + gio.NetworkAddresses = addrs + } + systemHostname, _ := os.Hostname() gio.Hostname = extractDeviceName(ctx, systemHostname) gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) - return gio } - -func sysInfo() (serialNumber string, productName string, manufacturer string) { - var err error - serialNumber, err = sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - productName, err = sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err = sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - 
- return serialNumber, productName, manufacturer -} - -func getOSNameAndVersion() (string, string) { - var dst []Win32_OperatingSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - log.Error(err) - return "Windows", getBuildVersion() - } - - if len(dst) == 0 { - return "Windows", getBuildVersion() - } - - split := strings.Split(dst[0].Caption, " ") - - if len(split) <= 3 { - return "Windows", getBuildVersion() - } - - name := split[1] - version := split[2] - if split[2] == "Server" { - name = fmt.Sprintf("%s %s", split[1], split[2]) - version = split[3] - } - - return name, version -} - -func getBuildVersion() string { - k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) - if err != nil { - log.Error(err) - return "0.0.0.0" - } - defer func() { - deferErr := k.Close() - if deferErr != nil { - log.Error(deferErr) - } - }() - - major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") - if err != nil { - log.Error(err) - } - minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") - if err != nil { - log.Error(err) - } - build, _, err := k.GetStringValue("CurrentBuildNumber") - if err != nil { - log.Error(err) - } - // Update Build Revision - ubr, _, err := k.GetIntegerValue("UBR") - if err != nil { - log.Error(err) - } - ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) - return ver -} - -func sysNumber() (string, error) { - var dst []Win32_BIOS - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - return dst[0].SerialNumber, nil -} - -func sysProductName() (string, error) { - var dst []Win32_ComputerSystemProduct - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - // `ComputerSystemProduct` could be empty on some virtualized systems - if len(dst) < 1 { - return "unknown", nil - } - return dst[0].Name, nil -} - -func 
sysManufacturer() (string, error) { - var dst []Win32_ComputerSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - return dst[0].Manufacturer, nil -} diff --git a/client/system/static_info.go b/client/system/static_info.go index f178ec932..12a2663a1 100644 --- a/client/system/static_info.go +++ b/client/system/static_info.go @@ -3,12 +3,7 @@ package system import ( - "context" "sync" - "time" - - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" ) var ( @@ -16,25 +11,26 @@ var ( once sync.Once ) -func updateStaticInfo() StaticInfo { +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment + + // Windows specific fields + OSName string + OSVersion string + BuildVersion string +} + +func updateStaticInfo() { once.Do(func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - wg := sync.WaitGroup{} - wg.Add(3) - go func() { - staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() - wg.Done() - }() - go func() { - staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) - wg.Done() - }() - go func() { - staticInfo.Environment.Platform = detect_platform.Detect(ctx) - wg.Done() - }() - wg.Wait() + staticInfo = newStaticInfo() }) +} + +func getStaticInfo() StaticInfo { + updateStaticInfo() return staticInfo } diff --git a/client/system/static_info_stub.go b/client/system/static_info_stub.go deleted file mode 100644 index faa3e700b..000000000 --- a/client/system/static_info_stub.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build android || freebsd || ios - -package system - -// updateStaticInfo returns an empty implementation for unsupported platforms -func updateStaticInfo() StaticInfo { - return 
StaticInfo{} -} diff --git a/client/system/static_info_update.go b/client/system/static_info_update.go new file mode 100644 index 000000000..af8b1e266 --- /dev/null +++ b/client/system/static_info_update.go @@ -0,0 +1,35 @@ +//go:build (linux && !android) || (darwin && !ios) + +package system + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +func newStaticInfo() StaticInfo { + si := StaticInfo{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + si.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + si.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() + return si +} diff --git a/client/system/static_info_update_windows.go b/client/system/static_info_update_windows.go new file mode 100644 index 000000000..5f232c1de --- /dev/null +++ b/client/system/static_info_update_windows.go @@ -0,0 +1,184 @@ +package system + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "github.com/yusufpapurcu/wmi" + "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +type Win32_OperatingSystem struct { + Caption string +} + +type Win32_ComputerSystem struct { + Manufacturer string +} + +type Win32_ComputerSystemProduct struct { + Name string +} + +type Win32_BIOS struct { + SerialNumber string +} + +func newStaticInfo() StaticInfo { + si := StaticInfo{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + si.SystemSerialNumber, si.SystemProductName, 
si.SystemManufacturer = sysInfo() + wg.Done() + }() + wg.Add(1) + go func() { + si.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + wg.Add(1) + go func() { + si.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Add(1) + go func() { + si.OSName, si.OSVersion = getOSNameAndVersion() + wg.Done() + }() + wg.Add(1) + go func() { + si.BuildVersion = getBuildVersion() + wg.Done() + }() + wg.Wait() + return si +} + +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + +func sysNumber() (string, error) { + var dst []Win32_BIOS + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].SerialNumber, nil +} + +func sysProductName() (string, error) { + var dst []Win32_ComputerSystemProduct + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + // `ComputerSystemProduct` could be empty on some virtualized systems + if len(dst) < 1 { + return "unknown", nil + } + return dst[0].Name, nil +} + +func sysManufacturer() (string, error) { + var dst []Win32_ComputerSystem + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].Manufacturer, nil +} + +func getOSNameAndVersion() (string, string) { + var dst []Win32_OperatingSystem + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + log.Error(err) + return "Windows", getBuildVersion() + } + + if len(dst) == 0 { + return "Windows", 
getBuildVersion() + } + + split := strings.Split(dst[0].Caption, " ") + + if len(split) <= 3 { + return "Windows", getBuildVersion() + } + + name := split[1] + version := split[2] + if split[2] == "Server" { + name = fmt.Sprintf("%s %s", split[1], split[2]) + version = split[3] + } + + return name, version +} + +func getBuildVersion() string { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) + if err != nil { + log.Error(err) + return "0.0.0.0" + } + defer func() { + deferErr := k.Close() + if deferErr != nil { + log.Error(deferErr) + } + }() + + major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") + if err != nil { + log.Error(err) + } + minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") + if err != nil { + log.Error(err) + } + build, _, err := k.GetStringValue("CurrentBuildNumber") + if err != nil { + log.Error(err) + } + // Update Build Revision + ubr, _, err := k.GetIntegerValue("UBR") + if err != nil { + log.Error(err) + } + ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) + return ver +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 78dd696db..533bf23d3 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -469,6 +469,17 @@ func (s *serviceClient) getConnectionForm() *widget.Form { } func (s *serviceClient) saveSettings() { + // Check if update settings are disabled by daemon + features, err := s.getFeatures() + if err != nil { + log.Errorf("failed to get features from daemon: %v", err) + // Continue with default behavior if features can't be retrieved + } else if features != nil && features.DisableUpdateSettings { + log.Warn("Configuration updates are disabled by daemon") + dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings) + return + } + if err := s.validateSettings(); err != nil { dialog.ShowError(err, s.wSettings) return @@ -605,6 +616,28 @@ func (s *serviceClient) 
sendConfigUpdate(req *proto.SetConfigRequest) error { return fmt.Errorf("set config: %w", err) } + // Reconnect if connected to apply the new settings + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + log.Errorf("get service status: %v", err) + return + } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + return + } + } + }() + return nil } @@ -637,7 +670,7 @@ func (s *serviceClient) getNetworkForm() *widget.Form { {Text: "Disable DNS", Widget: s.sDisableDNS}, {Text: "Disable Client Routes", Widget: s.sDisableClientRoutes}, {Text: "Disable Server Routes", Widget: s.sDisableServerRoutes}, - {Text: "Block LAN Access", Widget: s.sBlockLANAccess}, + {Text: "Disable LAN Access", Widget: s.sBlockLANAccess}, }, } } diff --git a/flow/client/client.go b/flow/client/client.go index 949824065..603fd6882 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -20,9 +20,9 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) type GRPCClient struct { diff --git a/go.mod b/go.mod index 4b9064dbc..f135b5fc6 100644 --- a/go.mod +++ b/go.mod @@ -12,19 +12,18 @@ require ( github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.6 - github.com/pion/ice/v3 v3.0.2 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.39.0 - golang.org/x/sys v0.33.0 + golang.org/x/crypto v0.40.0 + 
golang.org/x/sys v0.34.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.64.1 - google.golang.org/protobuf v1.36.6 + google.golang.org/grpc v1.73.0 + google.golang.org/protobuf v1.36.8 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -63,16 +62,18 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 - github.com/pion/logging v0.2.2 + github.com/pion/ice/v4 v4.0.0-00010101000000-000000000000 + github.com/pion/logging v0.2.4 github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 - github.com/pion/transport/v3 v3.0.1 + github.com/pion/stun/v3 v3.0.0 + github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 github.com/pkg/sftp v1.13.9 github.com/prometheus/client_golang v1.22.0 @@ -94,18 +95,18 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 - go.opentelemetry.io/otel v1.26.0 + go.opentelemetry.io/otel v1.35.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 - go.opentelemetry.io/otel/metric v1.26.0 - go.opentelemetry.io/otel/sdk/metric v1.26.0 + go.opentelemetry.io/otel/metric v1.35.0 + go.opentelemetry.io/otel/sdk/metric v1.35.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.40.0 - golang.org/x/oauth2 v0.27.0 - golang.org/x/sync v0.15.0 - golang.org/x/term v0.32.0 + golang.org/x/net v0.42.0 + golang.org/x/oauth2 v0.28.0 + golang.org/x/sync v0.16.0 + golang.org/x/term v0.33.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -118,7 +119,7 @@ require ( require ( cloud.google.com/go/auth v0.3.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.3.0 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect @@ -214,8 +215,10 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect - github.com/pion/mdns v0.0.12 // indirect + github.com/pion/dtls/v3 v3.0.7 // indirect + github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/transport/v2 v2.2.4 // indirect + github.com/pion/turn/v4 v4.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect @@ -231,22 +234,23 @@ require ( github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + github.com/wlynxg/anet v0.0.3 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect - go.opentelemetry.io/otel/sdk v1.26.0 // indirect - go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.opentelemetry.io/otel/sdk v1.35.0 // indirect + 
go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.25.0 // indirect - golang.org/x/text v0.26.0 // indirect + golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.33.0 // indirect + golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) @@ -259,6 +263,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 -replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e +replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index f3a9a1788..f2c51bf67 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod 
h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= @@ -502,10 +502,10 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= -github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= -github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 h1:/ZbExdcDwRq6XgTpTf5I1DPqnC3eInEf0fcmkqR8eSg= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod 
h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -547,21 +547,29 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= +github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= +github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= -github.com/pion/mdns v0.0.12/go.mod h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= +github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= +github.com/pion/stun/v3 v3.0.0 
h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= +github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= -github.com/pion/transport/v3 v3.0.1 h1:gDTlPJwROfSfz6QfSi0ZmeCSkFcnWWiiR9ES0ouANiM= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= +github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= +github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -591,8 +599,8 @@ github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0 github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= 
+github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= @@ -684,6 +692,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -715,26 +725,28 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 
h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= -go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= -go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= -go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= -go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= -go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= -go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= -go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= -go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= -go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.opentelemetry.io/otel/metric v1.35.0 
h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= +go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -766,8 +778,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= -golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -866,8 +878,8 @@ golang.org/x/net 
v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -881,8 +893,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= -golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= +golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ 
-900,8 +912,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -973,8 +985,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -987,8 +999,8 @@ golang.org/x/term 
v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= +golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= +golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1005,8 +1017,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1071,8 +1083,8 @@ golang.org/x/tools v0.1.12/go.mod 
h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= -golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1155,10 +1167,11 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto 
v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= +google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM= +google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1179,8 +1192,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= +google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1195,11 +1208,10 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod 
h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2d7c65cbe..cfec1000e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -328,6 +328,45 @@ delete_auto_service_user() { echo "$PARSED_RESPONSE" } +delete_default_zitadel_admin() { + INSTANCE_URL=$1 + PAT=$2 + + # Search for the default zitadel-admin user + RESPONSE=$( + curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + -d '{ + "queries": [ + { + "userNameQuery": { + "userName": "zitadel-admin@", + "method": 
"TEXT_QUERY_METHOD_STARTS_WITH" + } + } + ] + }' + ) + + DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty') + + if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then + echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID" + + RESPONSE=$( + curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + ) + PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"') + handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE" + + else + echo "Default zitadel-admin user not found: $RESPONSE" + fi +} + init_zitadel() { echo -e "\nInitializing Zitadel with NetBird's applications\n" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" @@ -346,6 +385,9 @@ init_zitadel() { echo -n "Waiting for Zitadel to become ready " wait_api "$INSTANCE_URL" "$PAT" + echo "Deleting default zitadel-admin user..." + delete_default_zitadel_admin "$INSTANCE_URL" "$PAT" + # create the zitadel project echo "Creating new zitadel project" PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") diff --git a/management/README.md b/management/README.md index 1122a9e76..c70285d43 100644 --- a/management/README.md +++ b/management/README.md @@ -111,3 +111,6 @@ Generate gRpc code: #!/bin/bash protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=. 
``` + + + diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index b351f3bc9..984a56a39 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -20,7 +20,11 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore()) + integratedPeerValidator, err := integrations.NewIntegratedValidator( + context.Background(), + s.PeersManager(), + s.SettingsManager(), + s.EventStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index f217eadb3..ee9f294a4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -104,6 +104,8 @@ type DefaultAccountManager struct { accountUpdateLocks sync.Map updateAccountPeersBufferInterval atomic.Int64 + loginFilter *loginFilter + disableDefaultPolicy bool } @@ -211,6 +213,7 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, + loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, } @@ -1133,7 +1136,18 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID - err := am.Store.SaveUser(ctx, newUser) + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID) + if err != nil { + return "", err + } + + if settings != nil && settings.Extra != nil && 
settings.Extra.UserApprovalRequired { + newUser.Blocked = true + newUser.PendingApproval = true + } + + err = am.Store.SaveUser(ctx, newUser) if err != nil { return "", err } @@ -1143,7 +1157,11 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, return "", err } - am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) + if newUser.PendingApproval { + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true}) + } else { + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) + } return domainAccountID, nil } @@ -1612,6 +1630,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } +func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { + return am.loginFilter.allowLogin(wgPubKey, metahash) +} + func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { start := time.Now() defer func() { @@ -1628,6 +1650,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } + metahash := metaHash(meta, realIP.String()) + am.loginFilter.addLogin(peerPubKey, metahash) + return peer, netMap, postureChecks, nil } @@ -1636,7 +1661,6 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } - return nil } @@ -1690,7 +1714,9 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account 
log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err) continue } - peers = append(peers, peer) + if peer.UserID != "" { + peers = append(peers, peer) + } } if len(peers) > 0 { err := am.expireAndUpdatePeers(ctx, accountID, peers) @@ -1786,6 +1812,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, + Extra: &types.ExtraSettings{ + UserApprovalRequired: true, + }, }, Onboarding: types.AccountOnboarding{ OnboardingFlowPending: true, @@ -1892,6 +1921,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, + Extra: &types.ExtraSettings{ + UserApprovalRequired: true, + }, }, } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index c7a39004a..30fbbbc3e 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -32,6 +32,8 @@ type Manager interface { DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) + RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) SaveOrAddUser(ctx 
context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) @@ -123,4 +125,5 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index 252be23f7..81a921bf9 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/prometheus/client_golang/prometheus/push" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -3046,19 +3048,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "sync", "syncAndMark") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - 
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3121,19 +3118,14 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "existingPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3196,24 +3188,44 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } } +func TestMain(m *testing.M) { + exitCode := m.Run() 
+ + if exitCode == 0 && os.Getenv("CI") == "true" { + runID := os.Getenv("GITHUB_RUN_ID") + storeEngine := os.Getenv("NETBIRD_STORE_ENGINE") + err := push.New("http://localhost:9091", "account_manager_benchmark"). + Collector(testing_tools.BenchmarkDuration). + Grouping("ci_run", runID). + Grouping("store_engine", storeEngine). + Push() + if err != nil { + log.Printf("Failed to push metrics: %v", err) + } else { + time.Sleep(1 * time.Minute) + _ = push.New("http://localhost:9091", "account_manager_benchmark"). + Grouping("ci_run", runID). + Grouping("store_engine", storeEngine). + Delete() + } + } + + os.Exit(exitCode) +} + func Test_GetCreateAccountByPrivateDomain(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -3594,3 +3606,93 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { require.Error(t, err, "should fail with invalid peer ID") }) } + +func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create a domain-based account with user approval enabled + existingAccountID := "existing-account" + account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false) + account.Settings.Extra = &types.ExtraSettings{ + UserApprovalRequired: true, + } + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Set the account as domain primary account + account.IsDomainPrimaryAccount = true + account.DomainCategory = types.PrivateCategory + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Test adding new user to existing account with approval required + newUserID := "new-user-id" + userAuth := nbcontext.UserAuth{ + UserId: newUserID, + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + + acc, err := manager.Store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + require.True(t, 
acc.IsDomainPrimaryAccount, "Account should be primary for the domain") + require.Equal(t, "example.com", acc.Domain, "Account domain should match") + + returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) + require.NoError(t, err) + require.Equal(t, existingAccountID, returnedAccountID) + + // Verify user was created with pending approval + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) + require.NoError(t, err) + assert.True(t, user.Blocked, "User should be blocked when approval is required") + assert.True(t, user.PendingApproval, "User should be pending approval") + assert.Equal(t, existingAccountID, user.AccountID) +} + +func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create a domain-based account without user approval + ownerUserAuth := nbcontext.UserAuth{ + UserId: "owner-user", + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth) + require.NoError(t, err) + + // Modify the account to disable user approval + account, err := manager.Store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + account.Settings.Extra = &types.ExtraSettings{ + UserApprovalRequired: false, + } + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Test adding new user to existing account without approval required + newUserID := "new-user-id" + userAuth := nbcontext.UserAuth{ + UserId: newUserID, + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + + returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) + require.NoError(t, err) + require.Equal(t, existingAccountID, returnedAccountID) + + // Verify user was created without 
pending approval + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) + require.NoError(t, err) + assert.False(t, user.Blocked, "User should not be blocked when approval is not required") + assert.False(t, user.PendingApproval, "User should not be pending approval") + assert.Equal(t, existingAccountID, user.AccountID) +} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 6f9619597..5c5989f84 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -177,6 +177,8 @@ const ( AccountNetworkRangeUpdated Activity = 87 PeerIPUpdated Activity = 88 + UserApproved Activity = 89 + UserRejected Activity = 90 AccountDeleted Activity = 99999 ) @@ -284,6 +286,8 @@ var activityMap = map[Activity]Code{ AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"}, PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, + UserApproved: {"User approved", "user.approve"}, + UserRejected: {"User rejected", "user.reject"}, } // StringCode returns a string code of the activity diff --git a/management/server/group.go b/management/server/group.go index 86bc0d8a0..487cb6d97 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -202,35 +202,45 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { + var globalErr error + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, 
newGroup) + + if err = transaction.CreateGroup(ctx, newGroup); err != nil { + return err + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return err + } + groupIDs = append(groupIDs, newGroup.ID) events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + return nil + }) if err != nil { - return err + log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) + if len(groupIDs) == 1 { + return err + } + globalErr = errors.Join(globalErr, err) + // continue updating other groups } + } - if err = transaction.CreateGroups(ctx, accountID, groupsToSave); err != nil { - return err - } - - return transaction.IncrementNetworkSerial(ctx, accountID) - }) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err } @@ -243,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return nil + return globalErr } // UpdateGroups updates groups in the account. 
@@ -260,35 +270,45 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { + var globalErr error + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) - groupIDs = append(groupIDs, newGroup.ID) + + if err = transaction.UpdateGroup(ctx, newGroup); err != nil { + return err + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return err + } events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + groupIDs = append(groupIDs, newGroup.ID) + + return nil + }) if err != nil { - return err + log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) + if len(groups) == 1 { + return err + } + globalErr = errors.Join(globalErr, err) + // continue updating other groups } + } - if err = transaction.UpdateGroups(ctx, accountID, groupsToSave); err != nil { - return err - } - - return transaction.IncrementNetworkSerial(ctx, accountID) - }) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err } @@ -301,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return nil + return globalErr } // prepareGroupEvents prepares a list of event functions to be stored. 
@@ -584,13 +604,6 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st newGroup.ID = xid.New().String() } - for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) - if err != nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } - } - return nil } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index a637cf02d..60a00207e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -2,9 +2,11 @@ package server import ( "context" + "errors" "fmt" "net" "net/netip" + "os" "strings" "sync" "time" @@ -38,20 +40,28 @@ import ( internalStatus "github.com/netbirdio/netbird/shared/management/status" ) +const ( + envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" + envBlockPeers = "NB_BLOCK_SAME_PEERS" +) + // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { accountManager account.Manager settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager - peerLocks sync.Map - authManager auth.Manager - integratedPeerValidator integrated_validator.IntegratedValidator + peersUpdateManager *PeersUpdateManager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager + peerLocks sync.Map + authManager auth.Manager + + logBlockedPeers bool + blockPeersWithSameConfig bool + integratedPeerValidator integrated_validator.IntegratedValidator } // NewServer creates a new Management server @@ -82,18 +92,23 @@ func NewServer( } } + logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" + blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" + return 
&GRPCServer{ wgKey: key, // peerKey -> event channel - peersUpdateManager: peersUpdateManager, - accountManager: accountManager, - settingsManager: settingsManager, - config: config, - secretsManager: secretsManager, - authManager: authManager, - appMetrics: appMetrics, - ephemeralManager: ephemeralManager, - integratedPeerValidator: integratedPeerValidator, + peersUpdateManager: peersUpdateManager, + accountManager: accountManager, + settingsManager: settingsManager, + config: config, + secretsManager: secretsManager, + authManager: authManager, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, + logBlockedPeers: logBlockedPeers, + blockPeersWithSameConfig: blockPeersWithSameConfig, + integratedPeerValidator: integratedPeerValidator, }, nil } @@ -136,9 +151,6 @@ func getRealIP(ctx context.Context) net.IP { // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { reqStart := time.Now() - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequest() - } ctx := srv.Context() @@ -147,6 +159,25 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if err != nil { return err } + realIP := getRealIP(ctx) + sRealIP := realIP.String() + peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) + metahashed := metaHash(peerMeta, sRealIP) + if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() + } + if s.logBlockedPeers { + log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) + } + if s.blockPeersWithSameConfig { + return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) + } + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequest() + } // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) @@ 
-172,14 +203,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - realIP := getRealIP(ctx) - log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) return mapError(ctx, err) @@ -198,7 +228,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } unlock() @@ -228,6 +258,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -345,6 +376,9 @@ func mapError(ctx context.Context, err error) error { default: } } + if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) { + return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error()) + } log.WithContext(ctx).Errorf("got an unhandled error: %s", 
err) return status.Errorf(codes.Internal, "failed handling request") } @@ -436,16 +470,9 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case of the successful registration login is also successful func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() - defer func() { - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) - } - }() - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequest() - } realIP := getRealIP(ctx) - log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + sRealIP := realIP.String() + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) loginReq := &proto.LoginRequest{} peerKey, err := s.parseRequest(ctx, req, loginReq) @@ -453,6 +480,24 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, err } + peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) + metahashed := metaHash(peerMeta, sRealIP) + if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if s.logBlockedPeers { + log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) + } + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() + } + if s.blockPeersWithSameConfig { + return nil, internalStatus.ErrPeerAlreadyLoggedIn + } + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } + //nolint ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) @@ -463,6 +508,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + defer func() { + if s.appMetrics != nil { + 
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) + } + }() + if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP) @@ -483,7 +534,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{ WireGuardPubKey: peerKey.String(), SSHKey: string(sshKey), - Meta: extractPeerMeta(ctx, loginReq.GetMeta()), + Meta: peerMeta, UserID: userID, SetupKey: loginReq.GetSetupKey(), ConnectionIP: realIP, @@ -949,8 +1000,6 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (* return nil, mapError(ctx, err) } - s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) - log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start)) return &proto.Empty{}, nil diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 9f2afe29d..f1552d0ea 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -11,11 +11,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -198,6 +198,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.Extra != nil { 
settings.Extra = &types.ExtraSettings{ PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, @@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A if settings.Extra != nil { apiSettings.Extra = &api.AccountExtraSettings{ PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: settings.Extra.UserApprovalRequired, NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled, NetworkTrafficLogsGroups: settings.Extra.FlowGroups, NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled, diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 1dad33a6f..4b9b79fdc 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -15,11 +15,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) func initAccountsTestData(t *testing.T, account *types.Account) *handler { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 414c7b1b9..af501e151 100644 --- 
a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -14,11 +14,11 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // Handler is a handler that returns peers of the account @@ -354,7 +354,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD } return &api.Peer{ - CreatedAt: peer.CreatedAt, + CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, Ip: peer.IP.String(), @@ -391,33 +391,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn } return &api.PeerBatch{ - CreatedAt: peer.CreatedAt, - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.GetLastLogin(), - LoginExpired: peer.Status.LoginExpired, - AccessiblePeersCount: accessiblePeersCount, - CountryCode: peer.Location.CountryCode, - 
CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, - + CreatedAt: peer.CreatedAt, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.GetLastLogin(), + LoginExpired: peer.Status.LoginExpired, + AccessiblePeersCount: accessiblePeersCount, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, + Ephemeral: peer.Ephemeral, } } diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index bcd637db4..4e03e5e9b 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -9,11 +9,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" nbcontext "github.com/netbirdio/netbird/management/server/context" ) @@ -31,6 +31,8 @@ func 
AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS") addUsersTokensEndpoint(accountManager, router) } @@ -323,17 +325,76 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { } isCurrent := user.ID == currenUserID + return &api.User{ - Id: user.ID, - Name: user.Name, - Email: user.Email, - Role: user.Role, - AutoGroups: autoGroups, - Status: userStatus, - IsCurrent: &isCurrent, - IsServiceUser: &user.IsServiceUser, - IsBlocked: user.IsBlocked, - LastLogin: &user.LastLogin, - Issued: &user.Issued, + Id: user.ID, + Name: user.Name, + Email: user.Email, + Role: user.Role, + AutoGroups: autoGroups, + Status: userStatus, + IsCurrent: &isCurrent, + IsServiceUser: &user.IsServiceUser, + IsBlocked: user.IsBlocked, + LastLogin: &user.LastLogin, + Issued: &user.Issued, + PendingApproval: user.PendingApproval, } } + +// approveUser is a POST request to approve a user that is pending approval +func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, 
targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + userResponse := toUserResponse(user, userAuth.UserId) + util.WriteJSONObject(r.Context(), w, userResponse) +} + +// rejectUser is a DELETE request to reject a user that is pending approval +func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index f7dc81919..e08004218 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -16,13 +16,13 @@ import ( "github.com/stretchr/testify/require" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) 
const ( @@ -725,3 +725,133 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri } return modules } + +func TestApproveUserEndpoint(t *testing.T) { + adminUser := &types.User{ + Id: "admin-user", + Role: types.UserRoleAdmin, + AccountID: existingAccountID, + AutoGroups: []string{}, + } + + pendingUser := &types.User{ + Id: "pending-user", + Role: types.UserRoleUser, + AccountID: existingAccountID, + Blocked: true, + PendingApproval: true, + AutoGroups: []string{}, + } + + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestingUser *types.User + }{ + { + name: "approve user as admin should return 200", + expectedStatus: 200, + expectedBody: true, + requestingUser: adminUser, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{} + am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + approvedUserInfo := &types.UserInfo{ + ID: pendingUser.Id, + Email: "pending@example.com", + Name: "Pending User", + Role: string(pendingUser.Role), + AutoGroups: []string{}, + IsServiceUser: false, + IsBlocked: false, + PendingApproval: false, + LastLogin: time.Now(), + Issued: types.UserIssuedAPI, + } + return approvedUserInfo, nil + } + + handler := newHandler(am) + router := mux.NewRouter() + router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST") + + req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) + require.NoError(t, err) + + userAuth := nbcontext.UserAuth{ + AccountId: existingAccountID, + UserId: tc.requestingUser.Id, + } + ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedBody { + var response api.User + err = json.Unmarshal(rr.Body.Bytes(), &response) + require.NoError(t, 
err) + assert.Equal(t, "pending-user", response.Id) + assert.False(t, response.IsBlocked) + assert.False(t, response.PendingApproval) + } + }) + } +} + +func TestRejectUserEndpoint(t *testing.T) { + adminUser := &types.User{ + Id: "admin-user", + Role: types.UserRoleAdmin, + AccountID: existingAccountID, + AutoGroups: []string{}, + } + + tt := []struct { + name string + expectedStatus int + requestingUser *types.User + }{ + { + name: "reject user as admin should return 200", + expectedStatus: 200, + requestingUser: adminUser, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{} + am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + return nil + } + + handler := newHandler(am) + router := mux.NewRouter() + router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE") + + req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) + require.NoError(t, err) + + userAuth := nbcontext.UserAuth{ + AccountId: existingAccountID, + UserId: tc.requestingUser.Id, + } + ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + }) + } +} diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 52737e4eb..3fe3fe809 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -17,8 +17,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + 
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) const modulePeers = "peers" @@ -47,7 +48,7 @@ func BenchmarkUpdatePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -65,7 +66,7 @@ func BenchmarkUpdatePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) }) } } @@ -82,7 +83,7 @@ func BenchmarkGetOnePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -92,7 +93,7 @@ func BenchmarkGetOnePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) }) } } @@ -109,7 +110,7 @@ func BenchmarkGetAllPeers(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := 
testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -119,7 +120,7 @@ func BenchmarkGetAllPeers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) }) } } @@ -136,7 +137,7 @@ func BenchmarkDeletePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -146,7 +147,7 @@ func BenchmarkDeletePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 9404c4ee4..36b226db0 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -17,8 +17,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - 
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) // Map to store peers, groups, users, and setupKeys by name @@ -47,7 +48,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -69,7 +70,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) }) } } @@ -86,7 +87,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -109,7 +110,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, 
moduleSetupKeys, testing_tools.OperationUpdate) }) } } @@ -126,7 +127,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -136,7 +137,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) }) } } @@ -153,7 +154,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -163,7 +164,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) }) } } @@ -180,7 +181,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, 
"../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) b.ResetTimer() @@ -190,7 +191,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 844b3e7a6..2868a20bd 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -18,8 +18,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) const moduleUsers = "users" @@ -46,7 +47,7 @@ func BenchmarkUpdateUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -71,7 +72,7 @@ func BenchmarkUpdateUser(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } 
- testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) }) } } @@ -84,18 +85,18 @@ func BenchmarkGetOneUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) for i := 0; i < b.N; i++ { - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) }) } } @@ -110,18 +111,18 @@ func BenchmarkGetAllUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) for i := 0; i < 
b.N; i++ { - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) }) } } @@ -136,7 +137,7 @@ func BenchmarkDeleteUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -147,7 +148,7 @@ func BenchmarkDeleteUsers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index 9f04e3c24..1079de4aa 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -15,9 +15,10 @@ import ( "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + 
"github.com/netbirdio/netbird/shared/management/http/api" ) func Test_SetupKeys_Create(t *testing.T) { @@ -287,7 +288,7 @@ func Test_SetupKeys_Create(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(user.name+" - "+tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -572,7 +573,7 @@ func Test_SetupKeys_Update(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -751,7 +752,7 @@ func Test_SetupKeys_Get(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) @@ -903,7 +904,7 @@ func Test_SetupKeys_GetAll(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) @@ -1087,7 +1088,7 @@ func Test_SetupKeys_Delete(t 
*testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go new file mode 100644 index 000000000..741f03f18 --- /dev/null +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -0,0 +1,137 @@ +package channel + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/netbirdio/management-integrations/integrations" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + http2 "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + 
"github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/users" +) + +func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) + done := make(chan struct{}) + if validateUpdate { + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + } + + geoMock := &geolocation.Mock{} + validatorMock := server.MockIntegratedValidator{} + proxyController := integrations.NewController(store) + userManager := users.NewManager(store) + permissionsManager := permissions.NewManager(store) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + // @note this is required so that PAT's validate from store, but JWT's are mocked + authManager := auth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: 
authManager.EnsureUserAccessByJWTGroups, + MarkPATUsedFunc: authManager.MarkPATUsed, + GetPATInfoFunc: authManager.GetPATInfo, + } + + networksManagerMock := networks.NewManagerMock() + resourcesManagerMock := resources.NewManagerMock() + routersManagerMock := routers.NewManagerMock() + groupsManagerMock := groups.NewManagerMock() + peersManager := peers.NewManager(store, permissionsManager) + + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, done +} + +func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + } +} + +func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { + userAuth := nbcontext.UserAuth{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + userAuth.UserId = token + userAuth.AccountId = "testAccountId" + userAuth.Domain = "test.com" + userAuth.DomainCategory = "private" + case "otherUserId": + userAuth.UserId = "otherUserId" + userAuth.AccountId = "otherAccountId" + userAuth.Domain = "other.com" + 
userAuth.DomainCategory = "private" + case "invalidToken": + return userAuth, nil, errors.New("invalid token") + } + + jwtToken := jwt.New(jwt.SigningMethodHS256) + return userAuth, jwtToken, nil +} diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index 1b82b156e..b7a63b104 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -3,7 +3,6 @@ package testing_tools import ( "bytes" "context" - "errors" "fmt" "io" "net" @@ -14,32 +13,12 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/users" - - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/management/server/groups" - nbhttp "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/networks" - "github.com/netbirdio/netbird/management/server/networks/resources" - "github.com/netbirdio/netbird/management/server/networks/routers" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/telemetry" 
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" ) @@ -106,90 +85,6 @@ type PerformanceMetrics struct { MaxMsPerOpCICD float64 } -func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) - if err != nil { - t.Fatalf("Failed to create test store: %v", err) - } - t.Cleanup(cleanup) - - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - if err != nil { - t.Fatalf("Failed to create metrics: %v", err) - } - - peersUpdateManager := server.NewPeersUpdateManager(nil) - updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId) - done := make(chan struct{}) - if validateUpdate { - go func() { - if expectedPeerUpdate != nil { - peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) - } else { - peerShouldNotReceiveUpdate(t, updMsg) - } - close(done) - }() - } - - geoMock := &geolocation.Mock{} - validatorMock := server.MockIntegratedValidator{} - proxyController := integrations.NewController(store) - userManager := users.NewManager(store) - permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) - if err != nil { - t.Fatalf("Failed to create manager: %v", err) - } - - // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := auth.NewManager(store, "", "", "", "", []string{}, false) - authManagerMock := &auth.MockManager{ - ValidateAndParseTokenFunc: mockValidateAndParseToken, - 
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, - MarkPATUsedFunc: authManager.MarkPATUsed, - GetPATInfoFunc: authManager.GetPATInfo, - } - - networksManagerMock := networks.NewManagerMock() - resourcesManagerMock := resources.NewManagerMock() - routersManagerMock := routers.NewManagerMock() - groupsManagerMock := groups.NewManagerMock() - peersManager := peers.NewManager(store, permissionsManager) - - apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) - if err != nil { - t.Fatalf("Failed to create API handler: %v", err) - } - - return apiHandler, am, done -} - -func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) { - t.Helper() - select { - case msg := <-updateMessage: - t.Errorf("Unexpected message received: %+v", msg) - case <-time.After(500 * time.Millisecond): - return - } -} - -func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { - t.Helper() - - select { - case msg := <-updateMessage: - if msg == nil { - t.Errorf("Received nil update message, expected valid message") - } - assert.Equal(t, expected, msg) - case <-time.After(500 * time.Millisecond): - t.Errorf("Timed out waiting for update message") - } -} - func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request { t.Helper() @@ -222,11 +117,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta return content, expectedStatus == http.StatusOK } -func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { +func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) { b.Helper() ctx := context.Background() - account, err := 
am.GetAccount(ctx, TestAccountId) + acc, err := am.GetAccount(ctx, TestAccountId) if err != nil { b.Fatalf("Failed to get account: %v", err) } @@ -242,23 +137,23 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: TestUserId, } - account.Peers[peer.ID] = peer + acc.Peers[peer.ID] = peer } // Create users for i := 0; i < users; i++ { user := &types.User{ Id: fmt.Sprintf("olduser-%d", i), - AccountID: account.Id, + AccountID: acc.Id, Role: types.UserRoleUser, } - account.Users[user.Id] = user + acc.Users[user.Id] = user } for i := 0; i < setupKeys; i++ { key := &types.SetupKey{ Id: fmt.Sprintf("oldkey-%d", i), - AccountID: account.Id, + AccountID: acc.Id, AutoGroups: []string{"someGroupID"}, UpdatedAt: time.Now().UTC(), ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)), @@ -266,11 +161,11 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro Type: "reusable", UsageLimit: 0, } - account.SetupKeys[key.Id] = key + acc.SetupKeys[key.Id] = key } // Create groups and policies - account.Policies = make([]*types.Policy, 0, groups) + acc.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) group := &types.Group{ @@ -281,7 +176,7 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } - account.Groups[groupID] = group + acc.Groups[groupID] = group // Create a policy for this group policy := &types.Policy{ @@ -301,10 +196,10 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro }, }, } - account.Policies = append(account.Policies, policy) + acc.Policies = append(acc.Policies, policy) } - account.PostureChecks = []*posture.Checks{ + acc.PostureChecks = []*posture.Checks{ { ID: "PostureChecksAll", Name: 
"All", @@ -316,52 +211,38 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro }, } - err = am.Store.SaveAccount(context.Background(), account) + store := am.GetStore() + + err = store.SaveAccount(context.Background(), acc) if err != nil { b.Fatalf("Failed to save account: %v", err) } } -func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { +func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { b.Helper() - branch := os.Getenv("GIT_BRANCH") - if branch == "" { - b.Fatalf("environment variable GIT_BRANCH is not set") - } - if recorder.Code != http.StatusOK { b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code) } + EvaluateBenchmarkResults(b, testCase, duration, module, operation) + +} + +func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) { + b.Helper() + + branch := os.Getenv("GIT_BRANCH") + if branch == "" && os.Getenv("CI") == "true" { + b.Fatalf("environment variable GIT_BRANCH is not set") + } + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch) gauge.Set(msPerOp) b.ReportMetric(msPerOp, "ms/op") - -} - -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} - - switch token { - case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": - userAuth.UserId = token - userAuth.AccountId = "testAccountId" - userAuth.Domain = "test.com" - userAuth.DomainCategory = "private" - case "otherUserId": - userAuth.UserId = "otherUserId" - userAuth.AccountId = "otherAccountId" - userAuth.Domain = "other.com" - 
userAuth.DomainCategory = "private" - case "invalidToken": - return userAuth, nil, errors.New("invalid token") - } - - jwtToken := jwt.New(jwt.SigningMethodHS256) - return userAuth, jwtToken, nil } diff --git a/management/server/loginfilter.go b/management/server/loginfilter.go new file mode 100644 index 000000000..8604af6e2 --- /dev/null +++ b/management/server/loginfilter.go @@ -0,0 +1,160 @@ +package server + +import ( + "hash/fnv" + "math" + "sync" + "time" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +const ( + reconnThreshold = 5 * time.Minute + baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit + reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban + metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer +) + +type lfConfig struct { + reconnThreshold time.Duration + baseBlockDuration time.Duration + reconnLimitForBan int + metaChangeLimit int +} + +func initCfg() *lfConfig { + return &lfConfig{ + reconnThreshold: reconnThreshold, + baseBlockDuration: baseBlockDuration, + reconnLimitForBan: reconnLimitForBan, + metaChangeLimit: metaChangeLimit, + } +} + +type loginFilter struct { + mu sync.RWMutex + cfg *lfConfig + logged map[string]*peerState +} + +type peerState struct { + currentHash uint64 + sessionCounter int + sessionStart time.Time + lastSeen time.Time + isBanned bool + banLevel int + banExpiresAt time.Time + metaChangeCounter int + metaChangeWindowStart time.Time +} + +func newLoginFilter() *loginFilter { + return newLoginFilterWithCfg(initCfg()) +} + +func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter { + return &loginFilter{ + logged: make(map[string]*peerState), + cfg: cfg, + } +} + +func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool { + l.mu.RLock() + defer func() { + l.mu.RUnlock() + }() + state, ok := l.logged[wgPubKey] + if !ok { + return true 
+ } + if state.isBanned && time.Now().Before(state.banExpiresAt) { + return false + } + if metaHash != state.currentHash { + if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit { + return false + } + } + return true +} + +func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) { + now := time.Now() + l.mu.Lock() + defer func() { + l.mu.Unlock() + }() + + state, ok := l.logged[wgPubKey] + + if !ok { + l.logged[wgPubKey] = &peerState{ + currentHash: metaHash, + sessionCounter: 1, + sessionStart: now, + lastSeen: now, + metaChangeWindowStart: now, + metaChangeCounter: 1, + } + return + } + + if state.isBanned && now.After(state.banExpiresAt) { + state.isBanned = false + } + + if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) { + state.banLevel = 0 + } + + if metaHash != state.currentHash { + if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) { + state.metaChangeWindowStart = now + state.metaChangeCounter = 1 + } else { + state.metaChangeCounter++ + } + state.currentHash = metaHash + state.sessionCounter = 1 + state.sessionStart = now + state.lastSeen = now + return + } + + state.sessionCounter++ + if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold { + state.isBanned = true + state.banLevel++ + + backoffFactor := math.Pow(2, float64(state.banLevel-1)) + duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor) + state.banExpiresAt = now.Add(duration) + + state.sessionCounter = 0 + state.sessionStart = now + } + state.lastSeen = now +} + +func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 { + h := fnv.New64a() + + h.Write([]byte(meta.WtVersion)) + h.Write([]byte(meta.OSVersion)) + h.Write([]byte(meta.KernelVersion)) + h.Write([]byte(meta.Hostname)) + h.Write([]byte(meta.SystemSerialNumber)) + h.Write([]byte(pubip)) + + macs := uint64(0) + for _, na := range 
meta.NetworkAddresses { + for _, r := range na.Mac { + macs += uint64(r) + } + } + + return h.Sum64() + macs +} diff --git a/management/server/loginfilter_test.go b/management/server/loginfilter_test.go new file mode 100644 index 000000000..65782dd9d --- /dev/null +++ b/management/server/loginfilter_test.go @@ -0,0 +1,275 @@ +package server + +import ( + "hash/fnv" + "math" + "math/rand" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func testAdvancedCfg() *lfConfig { + return &lfConfig{ + reconnThreshold: 50 * time.Millisecond, + baseBlockDuration: 100 * time.Millisecond, + reconnLimitForBan: 3, + metaChangeLimit: 2, + } +} + +type LoginFilterTestSuite struct { + suite.Suite + filter *loginFilter +} + +func (s *LoginFilterTestSuite) SetupTest() { + s.filter = newLoginFilterWithCfg(testAdvancedCfg()) +} + +func TestLoginFilterTestSuite(t *testing.T) { + suite.Run(t, new(LoginFilterTestSuite)) +} + +func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.True(s.filter.allowLogin(pubKey, meta)) + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(1, s.filter.logged[pubKey].sessionCounter) +} + +func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + limit := s.filter.cfg.reconnLimitForBan + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + + s.False(s.filter.allowLogin(pubKey, meta)) + s.Require().Contains(s.filter.logged, pubKey) + s.True(s.filter.logged[pubKey].isBanned) +} + +func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + limit := s.filter.cfg.reconnLimitForBan + baseBan := s.filter.cfg.baseBlockDuration + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + 
s.Require().Contains(s.filter.logged, pubKey) + s.True(s.filter.logged[pubKey].isBanned) + s.Equal(1, s.filter.logged[pubKey].banLevel) + firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) + s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond)) + + s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second) + s.filter.logged[pubKey].isBanned = false + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + s.True(s.filter.logged[pubKey].isBanned) + s.Equal(2, s.filter.logged[pubKey].banLevel) + secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) + expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1)) + s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond)) +} + +func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.filter.logged[pubKey] = &peerState{ + isBanned: true, + banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)), + } + + s.True(s.filter.allowLogin(pubKey, meta)) + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.False(s.filter.logged[pubKey].isBanned) +} + +func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.filter.logged[pubKey] = &peerState{ + currentHash: meta, + banLevel: 3, + lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration), + } + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(0, s.filter.logged[pubKey].banLevel) +} + +func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() { + pubKey := "PUB_KEY_A" + limit := s.filter.cfg.metaChangeLimit + + for i := range limit { + s.filter.addLogin(pubKey, uint64(i+1)) + } + + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter) + 
+ isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1)) + + s.False(isAllowed, "should block new meta hash after limit is reached") +} + +func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() { + pubKey := "PUB_KEY_A" + meta1 := uint64(1) + meta2 := uint64(2) + meta3 := uint64(3) + + s.filter.addLogin(pubKey, meta1) + s.filter.addLogin(pubKey, meta2) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter) + s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window") + + s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second)) + + s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires") + + s.filter.addLogin(pubKey, meta3) + s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset") +} + +func BenchmarkHashingMethods(b *testing.B) { + meta := nbpeer.PeerSystemMeta{ + WtVersion: "1.25.1", + OSVersion: "Ubuntu 22.04.3 LTS", + KernelVersion: "5.15.0-76-generic", + Hostname: "prod-server-database-01", + SystemSerialNumber: "PC-1234567890", + NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}}, + } + pubip := "8.8.8.8" + + var resultString string + var resultUint uint64 + + b.Run("BuilderString", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultString = builderString(meta, pubip) + } + }) + + b.Run("FnvHashToString", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultString = fnvHashToString(meta, pubip) + } + }) + + b.Run("FnvHashToUint64 - used", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultUint = metaHash(meta, pubip) + } + }) + + _ = resultString + _ = resultUint +} + +func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string { + h := fnv.New64a() + 
+ if len(meta.NetworkAddresses) != 0 { + for _, na := range meta.NetworkAddresses { + h.Write([]byte(na.Mac)) + } + } + + h.Write([]byte(meta.WtVersion)) + h.Write([]byte(meta.OSVersion)) + h.Write([]byte(meta.KernelVersion)) + h.Write([]byte(meta.Hostname)) + h.Write([]byte(meta.SystemSerialNumber)) + h.Write([]byte(pubip)) + + return strconv.FormatUint(h.Sum64(), 16) +} + +func builderString(meta nbpeer.PeerSystemMeta, pubip string) string { + mac := getMacAddress(meta.NetworkAddresses) + estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + + len(pubip) + len(mac) + 6 + + var b strings.Builder + b.Grow(estimatedSize) + + b.WriteString(meta.WtVersion) + b.WriteByte('|') + b.WriteString(meta.OSVersion) + b.WriteByte('|') + b.WriteString(meta.KernelVersion) + b.WriteByte('|') + b.WriteString(meta.Hostname) + b.WriteByte('|') + b.WriteString(meta.SystemSerialNumber) + b.WriteByte('|') + b.WriteString(pubip) + + return b.String() +} + +func getMacAddress(nas []nbpeer.NetworkAddress) string { + if len(nas) == 0 { + return "" + } + macs := make([]string, 0, len(nas)) + for _, na := range nas { + macs = append(macs, na.Mac) + } + return strings.Join(macs, "/") +} + +func BenchmarkLoginFilter_ParallelLoad(b *testing.B) { + filter := newLoginFilterWithCfg(testAdvancedCfg()) + numKeys := 100000 + pubKeys := make([]string, numKeys) + for i := range numKeys { + pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for pb.Next() { + key := pubKeys[r.Intn(numKeys)] + meta := r.Uint64() + + if filter.allowLogin(key, meta) { + filter.addLogin(key, meta) + } + } + }) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 6f9c2696f..003385eb5 100644 --- a/management/server/mock_server/account_mock.go +++ 
b/management/server/mock_server/account_mock.go @@ -95,6 +95,8 @@ type MockAccountManager struct { LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error + ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) + RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() account.ExternalCacheManager @@ -121,8 +123,10 @@ type MockAccountManager struct { GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + + AllowSyncFunc func(string, uint64) bool + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -605,6 +609,20 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented") } +func (am *MockAccountManager) ApproveUser(ctx 
context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + if am.ApproveUserFunc != nil { + return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented") +} + +func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + if am.RejectUserFunc != nil { + return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID) + } + return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented") +} + // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if am.GetNameServerGroupFunc != nil { @@ -953,3 +971,10 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n } return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } + +func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { + if am.AllowSyncFunc != nil { + return am.AllowSyncFunc(key, hash) + } + return true +} diff --git a/management/server/peer.go b/management/server/peer.go index 8af71cbd2..81f037499 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -368,10 +368,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { - return fmt.Errorf("failed to remove peer from groups: %w", err) - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -493,6 +489,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, 
status.Errorf(status.NotFound, "failed adding new peer: user not found") } + if user.PendingApproval { + return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") + } groupsToAdd = user.AutoGroups opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c4822aa62..31c309430 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -26,6 +26,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" @@ -989,19 +990,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -1609,7 +1605,6 @@ func Test_LoginPeer(t *testing.T) { testCases := []struct { name string setupKey string - wireGuardPubKey string expectExtraDNSLabelsMismatch bool extraDNSLabels []string expectLoginError bool @@ -2388,3 +2383,186 
@@ func TestBufferUpdateAccountPeers(t *testing.T) { assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) } + +func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Try to add peer with pending approval user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + peer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer) + require.Error(t, err) + assert.Contains(t, err.Error(), "user pending approval cannot add peers") +} + +func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create regular user (not pending approval) + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + // Try to add peer with regular user + key, err := wgtypes.GenerateKey() + 
require.NoError(t, err) + + peer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) + require.NoError(t, err, "Regular user should be able to add peers") +} + +func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Create a peer using AddPeer method for the pending user (simulate existing peer) + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + // Set the user to not be pending initially so peer can be added + pendingUser.Blocked = false + pendingUser.PendingApproval = false + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Add peer using regular flow + newPeer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + WtVersion: "0.28.0", + }, + } + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) + require.NoError(t, err) + + // Now set the user back to pending approval after peer was created + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Try to login with pending approval user + login := types.PeerLogin{ + 
WireGuardPubKey: existingPeer.Key, + UserID: pendingUser.Id, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.LoginPeer(context.Background(), login) + require.Error(t, err) + e, ok := status.FromError(err) + require.True(t, ok, "error is not a gRPC status error") + assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code") +} + +func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create regular user (not pending approval) + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + // Add peer using regular flow for the regular user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + WtVersion: "0.28.0", + }, + } + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) + require.NoError(t, err) + + // Try to login with regular user + login := types.PeerLogin{ + WireGuardPubKey: existingPeer.Key, + UserID: regularUser.Id, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.LoginPeer(context.Background(), login) + require.NoError(t, err, "Regular user should be able to login peers") +} diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go index 50e36a880..cb135f4ac 100644 --- a/management/server/peers/manager.go +++ b/management/server/peers/manager.go @@ -18,6 +18,7 @@ type Manager 
interface { GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) GetPeerAccountID(ctx context.Context, peerID string) (string, error) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) } type managerImpl struct { @@ -61,3 +62,7 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) } + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} diff --git a/management/server/peers/manager_mock.go b/management/server/peers/manager_mock.go index b247a1752..994f8346b 100644 --- a/management/server/peers/manager_mock.go +++ b/management/server/peers/manager_mock.go @@ -79,3 +79,18 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID) } + +// GetPeersByGroupIDs mocks base method. +func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs) + ret0, _ := ret[0].([]*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs. 
+func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) +} diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 0ab244243..891fa59bb 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -54,10 +54,14 @@ func (m *managerImpl) ValidateUserPermissions( return false, status.NewUserNotFoundError(userID) } - if user.IsBlocked() { + if user.IsBlocked() && !user.PendingApproval { return false, status.NewUserBlockedError() } + if user.IsBlocked() && user.PendingApproval { + return false, status.NewUserPendingApprovalError() + } + if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil { return false, err } diff --git a/management/server/policy.go b/management/server/policy.go index 312fd53b2..3adee6397 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -167,10 +167,22 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a // validatePolicy validates the policy and its rules. 
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return err } + + // TODO: Refactor to support multiple rules per policy + existingRuleIDs := make(map[string]bool) + for _, rule := range existingPolicy.Rules { + existingRuleIDs[rule.ID] = true + } + + for _, rule := range policy.Rules { + if rule.ID != "" && !existingRuleIDs[rule.ID] { + return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + } + } } else { policy.ID = xid.New().String() policy.AccountID = accountID diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 6ef93f0d1..027938320 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -914,7 +914,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { var account types.Account - result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account) + result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account) if result.Error != nil { return "", status.NewGetAccountFromStoreError(result.Error) } @@ -1399,7 +1399,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI PeerID: peerID, } - err := s.db.WithContext(ctx).Clauses(clause.OnConflict{ + err := s.db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, DoNothing: true, }).Create(peer).Error @@ -1414,7 +1414,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI // RemovePeerFromGroup removes a peer from a group func (s *SqlStore) RemovePeerFromGroup(ctx 
context.Context, peerID string, groupID string) error { - err := s.db.WithContext(ctx). + err := s.db. Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error if err != nil { @@ -1427,7 +1427,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group // RemovePeerFromAllGroups removes a peer from all groups func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { - err := s.db.WithContext(ctx). + err := s.db. Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error if err != nil { @@ -2015,7 +2015,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } @@ -2706,7 +2706,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -2847,3 +2847,22 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i } return nil } + +func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) { + if len(groupIDs) == 0 { + return []*nbpeer.Peer{}, nil + } + + var peers []*nbpeer.Peer + peerIDsSubquery := s.db.Model(&types.GroupPeer{}). + Select("DISTINCT peer_id"). + Where("account_id = ? 
AND group_id IN ?", accountID, groupIDs) + + result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by group IDs") + } + + return peers, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 935b0a595..d40c4664c 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3607,3 +3607,113 @@ func intToIPv4(n uint32) net.IP { binary.BigEndian.PutUint32(ip, n) return ip } + +func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group1ID := "test-group-1" + group2ID := "test-group-2" + emptyGroupID := "empty-group" + + peer1 := "cfefqs706sqkneg59g4g" + peer2 := "cfeg6sf06sqkneg59g50" + + tests := []struct { + name string + groupIDs []string + expectedPeers []string + expectedCount int + }{ + { + name: "retrieve peers from single group with multiple peers", + groupIDs: []string{group1ID}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from single group with one peer", + groupIDs: []string{group2ID}, + expectedPeers: []string{peer1}, + expectedCount: 1, + }, + { + name: "retrieve peers from multiple groups (with overlap)", + groupIDs: []string{group1ID, group2ID}, + expectedPeers: []string{peer1, peer2}, // should deduplicate + expectedCount: 2, + }, + { + name: "retrieve peers from existing 'All' group", + groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from empty group", + groupIDs: []string{emptyGroupID}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "retrieve peers from non-existing group", + groupIDs: []string{"non-existing-group"}, + 
expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "mix of existing and non-existing groups", + groupIDs: []string{group1ID, "non-existing-group"}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + + groups := []*types.Group{ + { + ID: group1ID, + AccountID: accountID, + }, + { + ID: group2ID, + AccountID: accountID, + }, + } + require.NoError(t, store.CreateGroups(ctx, accountID, groups)) + + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID)) + + peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + + if tt.expectedCount > 0 { + actualPeerIDs := make([]string, len(peers)) + for i, peer := range peers { + actualPeerIDs[i] = peer.ID + } + assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs) + + // Verify all returned peers belong to the correct account + for _, peer := range peers { + assert.Equal(t, accountID, peer.AccountID) + } + } + }) + } +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 545549410..3c9d896b0 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -136,6 +136,7 @@ type Store interface { GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, 
error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index ac6ff2ea8..d4301802f 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -4,20 +4,28 @@ import ( "context" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) +const AccountIDLabel = "account_id" +const HighLatencyThreshold = time.Second * 7 + // GRPCMetrics are gRPC server metrics type GRPCMetrics struct { - meter metric.Meter - syncRequestsCounter metric.Int64Counter - loginRequestsCounter metric.Int64Counter - getKeyRequestsCounter metric.Int64Counter - activeStreamsGauge metric.Int64ObservableGauge - syncRequestDuration metric.Int64Histogram - loginRequestDuration metric.Int64Histogram - channelQueueLength metric.Int64Histogram - ctx context.Context + meter metric.Meter + syncRequestsCounter metric.Int64Counter + syncRequestsBlockedCounter metric.Int64Counter + syncRequestHighLatencyCounter metric.Int64Counter + loginRequestsCounter metric.Int64Counter + loginRequestsBlockedCounter metric.Int64Counter + loginRequestHighLatencyCounter metric.Int64Counter + getKeyRequestsCounter metric.Int64Counter + activeStreamsGauge metric.Int64ObservableGauge + syncRequestDuration metric.Int64Histogram + loginRequestDuration metric.Int64Histogram + channelQueueLength metric.Int64Histogram + ctx 
context.Context } // NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server @@ -30,6 +38,22 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + syncRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.sync.request.blocked.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of sync gRPC requests from blocked peers"), + ) + if err != nil { + return nil, err + } + + syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"), + ) + if err != nil { + return nil, err + } + loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -38,6 +62,22 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + loginRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.login.request.blocked.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of login gRPC requests from blocked peers"), + ) + if err != nil { + return nil, err + } + + loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"), + ) + if err != nil { + return nil, err + } + getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of key gRPC 
requests from the peers to get the server's public WireGuard key"), @@ -83,15 +123,19 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro } return &GRPCMetrics{ - meter: meter, - syncRequestsCounter: syncRequestsCounter, - loginRequestsCounter: loginRequestsCounter, - getKeyRequestsCounter: getKeyRequestsCounter, - activeStreamsGauge: activeStreamsGauge, - syncRequestDuration: syncRequestDuration, - loginRequestDuration: loginRequestDuration, - channelQueueLength: channelQueue, - ctx: ctx, + meter: meter, + syncRequestsCounter: syncRequestsCounter, + syncRequestsBlockedCounter: syncRequestsBlockedCounter, + syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, + loginRequestsCounter: loginRequestsCounter, + loginRequestsBlockedCounter: loginRequestsBlockedCounter, + loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, + getKeyRequestsCounter: getKeyRequestsCounter, + activeStreamsGauge: activeStreamsGauge, + syncRequestDuration: syncRequestDuration, + loginRequestDuration: loginRequestDuration, + channelQueueLength: channelQueue, + ctx: ctx, }, err } @@ -100,6 +144,11 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequest() { grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1) } +// CountSyncRequestBlocked counts the number of gRPC sync requests from blocked peers +func (grpcMetrics *GRPCMetrics) CountSyncRequestBlocked() { + grpcMetrics.syncRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) +} + // CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() { grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1) @@ -110,14 +159,25 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() { grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1) } +// CountLoginRequestBlocked counts the number of gRPC login requests from blocked peers +func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() { + 
grpcMetrics.loginRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) +} + // CountLoginRequestDuration counts the duration of the login gRPC requests -func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration) { +func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) { grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + if duration > HighLatencyThreshold { + grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + } } // CountSyncRequestDuration counts the duration of the sync gRPC requests -func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration) { +func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + if duration > HighLatencyThreshold { + grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + } } // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. 
diff --git a/management/server/types/account.go b/management/server/types/account.go index 9ac2568a0..ca075b9f6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -302,7 +302,11 @@ func (a *Account) GetPeerNetworkMap( var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) } dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) @@ -1651,3 +1655,24 @@ func peerSupportsPortRanges(peerVer string) bool { meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) return err == nil && meetMinVer } + +// filterZoneRecordsForPeers filters DNS records to only include peers to connect. +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { + filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) + peerIPs := make(map[string]struct{}) + + // Add peer's own IP to include its own DNS records + peerIPs[peer.IP.String()] = struct{}{} + + for _, peerToConnect := range peersToConnect { + peerIPs[peerToConnect.IP.String()] = struct{}{} + } + + for _, record := range customZone.Records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f8ab1d627..cd221b590 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -2,14 +2,17 @@ package types import ( "context" + "fmt" "net" "net/netip" "slices" "testing" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns 
"github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -835,3 +838,109 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") } + +func Test_FilterZoneRecordsForPeers(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + customZone nbdns.CustomZone + peersToConnect []*nbpeer.Peer + expectedRecords []nbdns.SimpleRecord + }{ + { + name: "empty peers to connect", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + { + name: "multiple peers multiple records match", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for i := 1; i <= 100; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + peersToConnect: func() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + peers = append(peers, 
&nbpeer.Peer{ + ID: fmt.Sprintf("peer%d", i), + IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + }) + } + return peers + }(), + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + { + name: "peers with multiple DNS labels", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, + {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + 
expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + assert.Equal(t, len(tt.expectedRecords), len(result)) + assert.ElementsMatch(t, tt.expectedRecords, result) + }) + } +} diff --git a/management/server/types/network.go b/management/server/types/network.go index f072a4294..ffc019565 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -12,11 +12,11 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" ) const ( diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 56c33da3b..b4afb2f5e 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -83,6 +83,9 @@ type ExtraSettings struct { // PeerApprovalEnabled enables or disables the 
need for peers bo be approved by an administrator PeerApprovalEnabled bool + // UserApprovalRequired enables or disables the need for users joining via domain matching to be approved by an administrator + UserApprovalRequired bool + // IntegratedValidator is the string enum for the integrated validator type IntegratedValidator string // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations @@ -99,6 +102,7 @@ type ExtraSettings struct { func (e *ExtraSettings) Copy() *ExtraSettings { return &ExtraSettings{ PeerApprovalEnabled: e.PeerApprovalEnabled, + UserApprovalRequired: e.UserApprovalRequired, IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups), IntegratedValidator: e.IntegratedValidator, FlowEnabled: e.FlowEnabled, diff --git a/management/server/types/user.go b/management/server/types/user.go index 783fe14da..beb3586df 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -64,6 +64,7 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` + PendingApproval bool `json:"pending_approval"` IntegrationReference integration_reference.IntegrationReference `json:"-"` } @@ -84,6 +85,8 @@ type User struct { PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` // Blocked indicates whether the user is blocked. Blocked users can't use the system. 
Blocked bool + // PendingApproval indicates whether the user requires approval before being activated + PendingApproval bool // LastLogin is the last time the user logged in to IdP LastLogin *time.Time // CreatedAt records the time the user was created @@ -141,16 +144,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { if userData == nil { return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + PendingApproval: u.PendingApproval, }, nil } if userData.ID != u.Id { @@ -163,16 +167,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { } return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + PendingApproval: u.PendingApproval, }, nil } @@ -194,6 +199,7 @@ func (u *User) Copy() *User { ServiceUserName: u.ServiceUserName, PATs: pats, Blocked: u.Blocked, + PendingApproval: u.PendingApproval, LastLogin: u.LastLogin, CreatedAt: u.CreatedAt, Issued: u.Issued, diff --git a/management/server/user.go b/management/server/user.go index 4596ee95b..d40d33c6a 100644 --- a/management/server/user.go +++ b/management/server/user.go 
@@ -519,33 +519,46 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUser = result } - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - for _, update := range updates { - if update == nil { - return status.Errorf(status.InvalidArgument, "provided user update is nil") - } + var globalErr error + for _, update := range updates { + if update == nil { + return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } - usersToSave = append(usersToSave, updatedUser) - addUserEvents = append(addUserEvents, userEvents...) - peersToExpire = append(peersToExpire, userPeersToExpire...) if userHadPeers { updateAccountPeers = true } + + err = transaction.SaveUser(ctx, updatedUser) + if err != nil { + return fmt.Errorf("failed to save updated user %s: %w", update.Id, err) + } + + usersToSave = append(usersToSave, updatedUser) + addUserEvents = append(addUserEvents, userEvents...) + peersToExpire = append(peersToExpire, userPeersToExpire...) 
+ + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to save user %s: %s", update.Id, err) + if len(updates) == 1 { + return nil, err + } + globalErr = errors.Join(globalErr, err) + // continue when updating multiple users } - return transaction.SaveUsers(ctx, usersToSave) - }) - if err != nil { - return nil, err } - var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates)) + var updatedUsersInfo = make([]*types.UserInfo, 0, len(usersToSave)) userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID) if err != nil { @@ -578,7 +591,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, am.UpdateAccountPeers(ctx, accountID) } - return updatedUsersInfo, nil + return updatedUsersInfo, globalErr } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. @@ -643,7 +656,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } transferredOwnerRole = result - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, updatedUser.AccountID, update.Id) if err != nil { return false, nil, nil, nil, err } @@ -929,6 +942,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + if peer.UserID == "" { + // we do not want to expire peers that are added via setup key + continue + } + if peer.Status.LoginExpired { continue } @@ -947,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service + log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.BufferUpdateAccountPeers(ctx, accountID) } @@ 
-1194,3 +1213,77 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut return userWithPermissions, nil } + +// ApproveUser approves a user that is pending approval +func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.NewUserNotFoundError(targetUserID) + } + + if !user.PendingApproval { + return nil, status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) + } + + user.Blocked = false + user.PendingApproval = false + + err = am.Store.SaveUser(ctx, user) + if err != nil { + return nil, err + } + + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserApproved, nil) + + userInfo, err := am.getUserInfo(ctx, user, accountID) + if err != nil { + return nil, err + } + + return userInfo, nil +} + +// RejectUser rejects a user that is pending approval by deleting them +func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return 
status.NewUserNotFoundError(targetUserID) + } + + if !user.PendingApproval { + return status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) + } + + err = am.DeleteUser(ctx, accountID, initiatorUserID, targetUserID) + if err != nil { + return err + } + + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserRejected, nil) + + return nil +} diff --git a/management/server/user_test.go b/management/server/user_test.go index 8ab0c1565..9638559f9 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1746,3 +1746,117 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { return permissions } + +func TestApproveUser(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account with admin and pending approval user + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create admin user + adminUser := types.NewAdminUser("admin-user") + adminUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), adminUser) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Test successful approval + approvedUser, err := manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.NoError(t, err) + assert.False(t, approvedUser.IsBlocked) + assert.False(t, approvedUser.PendingApproval) + + // Verify user is updated in store + updatedUser, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) + require.NoError(t, err) + assert.False(t, 
updatedUser.Blocked) + assert.False(t, updatedUser.PendingApproval) + + // Test approval of non-pending user should fail + _, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not pending approval") + + // Test approval by non-admin should fail + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2.AccountID = account.Id + pendingUser2.Blocked = true + pendingUser2.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser2) + require.NoError(t, err) + + _, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) + require.Error(t, err) +} + +func TestRejectUser(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account with admin and pending approval user + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create admin user + adminUser := types.NewAdminUser("admin-user") + adminUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), adminUser) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Test successful rejection + err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.NoError(t, err) + + // Verify user is deleted from store + _, err = 
manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) + require.Error(t, err) + + // Test rejection of non-pending user should fail + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not pending approval") + + // Test rejection by non-admin should fail + pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2.AccountID = account.Id + pendingUser2.Blocked = true + pendingUser2.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser2) + require.NoError(t, err) + + err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) + require.Error(t, err) +} diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 332127660..12219e29b 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -73,7 +73,12 @@ func (l *Listener) Shutdown(ctx context.Context) error { func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { connRemoteAddr := remoteAddr(r) - wsConn, err := websocket.Accept(w, r, nil) + + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + } + + wsConn, err := websocket.Accept(w, r, acceptOptions) if err != nil { log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err) return diff --git a/release_files/install.sh b/release_files/install.sh index 856d332cb..5d5349ec4 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -130,36 +130,6 @@ repo_gpgcheck=1 EOF } -install_aur_package() { - INSTALL_PKGS="git base-devel go" - REMOVE_PKGS="" - - # Check if dependencies are installed - for PKG in $INSTALL_PKGS; do - if ! 
pacman -Q "$PKG" > /dev/null 2>&1; then - # Install missing package(s) - ${SUDO} pacman -S "$PKG" --noconfirm - - # Add installed package for clean up later - REMOVE_PKGS="$REMOVE_PKGS $PKG" - fi - done - - # Build package from AUR - cd /tmp && git clone https://aur.archlinux.org/netbird.git - cd netbird && makepkg -sri --noconfirm - - if ! $SKIP_UI_APP; then - cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git - cd netbird-ui && makepkg -sri --noconfirm - fi - - if [ -n "$REMOVE_PKGS" ]; then - # Clean up the installed packages - ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm - fi -} - prepare_tun_module() { # Create the necessary file structure for /dev/net/tun if [ ! -c /dev/net/tun ]; then @@ -276,12 +246,9 @@ install_netbird() { if ! $SKIP_UI_APP; then ${SUDO} rpm-ostree -y install netbird-ui fi - ;; - pacman) - ${SUDO} pacman -Syy - install_aur_package - # in-line with the docs at https://wiki.archlinux.org/title/Netbird - ${SUDO} systemctl enable --now netbird@main.service + # ensure the service is started after install + ${SUDO} netbird service install || true + ${SUDO} netbird service start || true ;; pkg) # Check if the package is already installed @@ -458,11 +425,7 @@ if type uname >/dev/null 2>&1; then elif [ -x "$(command -v yum)" ]; then PACKAGE_MANAGER="yum" echo "The installation will be performed using yum package manager" - elif [ -x "$(command -v pacman)" ]; then - PACKAGE_MANAGER="pacman" - echo "The installation will be performed using pacman package manager" fi - else echo "Unable to determine OS type from /etc/os-release" exit 1 diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 3037b44bb..becc10ded 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -9,34 +9,30 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/system" - 
"github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - - "github.com/netbirdio/management-integrations/integrations" - - "github.com/netbirdio/netbird/encryption" - mgmt "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/mock_server" - mgmtProto "github.com/netbirdio/netbird/shared/management/proto" - + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmt "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + 
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -72,13 +68,31 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + permissionsManagerMock := permissions.NewMockManager(ctrl) + permissionsManagerMock. + EXPECT(). + ValidateUserPermissions( + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any(), + ). + Return(true, nil). + AnyTimes() + + peersManger := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager. EXPECT(). @@ -95,19 +109,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - permissionsManagerMock := permissions.NewMockManager(ctrl) - permissionsManagerMock. - EXPECT(). - ValidateUserPermissions( - gomock.Any(), - gomock.Any(), - gomock.Any(), - gomock.Any(), - gomock.Any(), - ). - Return(true, nil). 
- AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index dc26253e9..f30e965be 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -17,11 +17,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second @@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/management/client/rest/client_test.go b/shared/management/client/rest/client_test.go index 56c859652..54a0290d0 100644 --- a/shared/management/client/rest/client_test.go +++ b/shared/management/client/rest/client_test.go @@ -8,8 +8,8 @@ import ( "net/http/httptest" "testing" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/client/rest" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" ) func withMockClient(callback func(*rest.Client, *http.ServeMux)) { @@ -26,7 +26,7 @@ func ptr[T any, PT *T](x T) PT { func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) { t.Helper() - handler, _, _ := 
testing_tools.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) + handler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) server := httptest.NewServer(handler) defer server.Close() c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp") diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index cf4b6d625..9a531b2ff 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -158,6 +158,10 @@ components: description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. type: boolean example: true + user_approval_required: + description: Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. + type: boolean + example: false network_traffic_logs_enabled: description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored. type: boolean @@ -174,6 +178,7 @@ components: example: true required: - peer_approval_enabled + - user_approval_required - network_traffic_logs_enabled - network_traffic_logs_groups - network_traffic_packet_counter_enabled @@ -235,6 +240,10 @@ components: description: Is true if this user is blocked. Blocked users can't use the system type: boolean example: false + pending_approval: + description: Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. 
+ type: boolean + example: false issued: description: How user was issued by API or Integration type: string @@ -249,6 +258,7 @@ components: - auto_groups - status - is_blocked + - pending_approval UserPermissions: type: object properties: @@ -2544,6 +2554,63 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/{userId}/approve: + post: + summary: Approve user + description: Approve a user that is pending approval + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The unique identifier of a user + responses: + '200': + description: Returns the approved user + content: + application/json: + schema: + "$ref": "#/components/schemas/User" + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/{userId}/reject: + delete: + summary: Reject user + description: Reject a user that is pending approval by removing them from the account + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The unique identifier of a user + responses: + '200': + description: User rejected successfully + content: {} + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/users/current: get: summary: Retrieve current user diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index cffc9e735..28b89633c 100644 --- a/shared/management/http/api/types.gen.go +++ 
b/shared/management/http/api/types.gen.go @@ -268,6 +268,9 @@ type AccountExtraSettings struct { // PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. PeerApprovalEnabled bool `json:"peer_approval_enabled"` + + // UserApprovalRequired Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. + UserApprovalRequired bool `json:"user_approval_required"` } // AccountOnboarding defines model for AccountOnboarding. @@ -1015,8 +1018,6 @@ type OSVersionCheck struct { // Peer defines model for Peer. type Peer struct { - // CreatedAt Peer creation date (UTC) - CreatedAt time.Time `json:"created_at"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval ApprovalRequired bool `json:"approval_required"` @@ -1032,6 +1033,9 @@ type Peer struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country CountryCode CountryCode `json:"country_code"` + // CreatedAt Peer creation date (UTC) + CreatedAt time.Time `json:"created_at"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1098,8 +1102,6 @@ type Peer struct { // PeerBatch defines model for PeerBatch. 
type PeerBatch struct { - // CreatedAt Peer creation date (UTC) - CreatedAt time.Time `json:"created_at"` // AccessiblePeersCount Number of accessible peers AccessiblePeersCount int `json:"accessible_peers_count"` @@ -1118,6 +1120,9 @@ type PeerBatch struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country CountryCode CountryCode `json:"country_code"` + // CreatedAt Peer creation date (UTC) + CreatedAt time.Time `json:"created_at"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1774,8 +1779,11 @@ type User struct { LastLogin *time.Time `json:"last_login,omitempty"` // Name User's name from idp provider - Name string `json:"name"` - Permissions *UserPermissions `json:"permissions,omitempty"` + Name string `json:"name"` + + // PendingApproval Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. 
+ PendingApproval bool `json:"pending_approval"` + Permissions *UserPermissions `json:"permissions,omitempty"` // Role User's NetBird account role Role string `json:"role"` diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 7660174d6..1e914babb 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -42,7 +42,10 @@ const ( // Type is a type of the Error type Type int32 -var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found") +var ( + ErrExtraSettingsNotFound = errors.New("extra settings not found") + ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in") +) // Error is an internal error type Error struct { @@ -110,6 +113,11 @@ func NewUserBlockedError() error { return Errorf(PermissionDenied, "user is blocked") } +// NewUserPendingApprovalError creates a new Error with PermissionDenied type for a blocked user pending approval +func NewUserPendingApprovalError() error { + return Errorf(PermissionDenied, "user is pending approval") +} + // NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer func NewPeerNotRegisteredError() error { return Errorf(Unauthenticated, "peer is not registered") diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index b496f6a9b..967e18d79 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" quictls "github.com/netbirdio/netbird/shared/relay/tls" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 109651f5d..ef6bd6b3c 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/shared/relay" 
"github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index a40343fb1..6220e7f6b 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -78,9 +78,10 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin tokenStore: tokenStore, mtu: mtu, serverPicker: &ServerPicker{ - TokenStore: tokenStore, - PeerID: peerID, - MTU: mtu, + TokenStore: tokenStore, + PeerID: peerID, + MTU: mtu, + ConnectionTimeout: defaultConnectionTimeout, }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), diff --git a/shared/relay/client/picker.go b/shared/relay/client/picker.go index b6c7b5e8a..39d0ba072 100644 --- a/shared/relay/client/picker.go +++ b/shared/relay/client/picker.go @@ -13,11 +13,8 @@ import ( ) const ( - maxConcurrentServers = 7 -) - -var ( - connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 + defaultConnectionTimeout = 30 * time.Second ) type connResult struct { @@ -27,14 +24,15 @@ type connResult struct { } type ServerPicker struct { - TokenStore *auth.TokenStore - ServerURLs atomic.Value - PeerID string - MTU uint16 + TokenStore *auth.TokenStore + ServerURLs atomic.Value + PeerID string + MTU uint16 + ConnectionTimeout time.Duration } func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { - ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout) defer cancel() totalServers := len(sp.ServerURLs.Load().([]string)) diff --git a/shared/relay/client/picker_test.go b/shared/relay/client/picker_test.go index 28167c5ce..fb3fa7375 100644 --- a/shared/relay/client/picker_test.go +++ b/shared/relay/client/picker_test.go @@ -8,15 +8,15 @@ import ( ) func 
TestServerPicker_UnavailableServers(t *testing.T) { - connectionTimeout = 5 * time.Second - + timeout := 5 * time.Second sp := ServerPicker{ - TokenStore: nil, - PeerID: "test", + TokenStore: nil, + PeerID: "test", + ConnectionTimeout: timeout, } sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) - ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) + ctx, cancel := context.WithTimeout(context.Background(), timeout+1) defer cancel() go func() { diff --git a/shared/relay/healthcheck/env.go b/shared/relay/healthcheck/env.go new file mode 100644 index 000000000..2b584c195 --- /dev/null +++ b/shared/relay/healthcheck/env.go @@ -0,0 +1,24 @@ +package healthcheck + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" +) + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. 
Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/shared/relay/healthcheck/env_test.go b/shared/relay/healthcheck/env_test.go new file mode 100644 index 000000000..2e14bb8bf --- /dev/null +++ b/shared/relay/healthcheck/env_test.go @@ -0,0 +1,36 @@ +package healthcheck + +import ( + "os" + "testing" +) + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/shared/relay/healthcheck/receiver.go b/shared/relay/healthcheck/receiver.go index b3503d5db..90f795bbe 100644 --- a/shared/relay/healthcheck/receiver.go +++ b/shared/relay/healthcheck/receiver.go @@ -7,10 +7,15 @@ import ( log "github.com/sirupsen/logrus" ) -var ( - heartbeatTimeout = healthCheckInterval + 10*time.Second +const ( + defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second ) +type ReceiverOptions struct { + HeartbeatTimeout time.Duration + AttemptThreshold int +} + // Receiver is a healthcheck receiver // It will listen for heartbeat and check if the heartbeat is not received in a certain time // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work @@ -27,6 +32,23 @@ type Receiver struct { // 
NewReceiver creates a new healthcheck receiver and start the timer in the background func NewReceiver(log *log.Entry) *Receiver { + opts := ReceiverOptions{ + HeartbeatTimeout: defaultHeartbeatTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewReceiverWithOpts(log, opts) +} + +func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver { + heartbeatTimeout := opts.HeartbeatTimeout + if heartbeatTimeout <= 0 { + heartbeatTimeout = defaultHeartbeatTimeout + } + attemptThreshold := opts.AttemptThreshold + if attemptThreshold <= 0 { + attemptThreshold = defaultAttemptThreshold + } + ctx, ctxCancel := context.WithCancel(context.Background()) r := &Receiver{ @@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver { ctx: ctx, ctxCancel: ctxCancel, heartbeat: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + attemptThreshold: attemptThreshold, } - go r.waitForHealthcheck() + go r.waitForHealthcheck(heartbeatTimeout) return r } @@ -55,7 +77,7 @@ func (r *Receiver) Stop() { r.ctxCancel() } -func (r *Receiver) waitForHealthcheck() { +func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) { ticker := time.NewTicker(heartbeatTimeout) defer ticker.Stop() defer r.ctxCancel() diff --git a/shared/relay/healthcheck/receiver_test.go b/shared/relay/healthcheck/receiver_test.go index 2794159f6..b20cc5124 100644 --- a/shared/relay/healthcheck/receiver_test.go +++ b/shared/relay/healthcheck/receiver_test.go @@ -2,31 +2,18 @@ package healthcheck import ( "context" - "fmt" - "os" - "sync" "testing" "time" log "github.com/sirupsen/logrus" ) -// Mutex to protect global variable access in tests -var testMutex sync.Mutex - func TestNewReceiver(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 5 * time.Second - testMutex.Unlock() - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := 
NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 5 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) { } func TestNewReceiverNotReceive(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 1 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 1 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) { } func TestNewReceiverAck(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 2 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 2 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() r.Heartbeat() @@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - testMutex.Lock() - originalInterval := healthCheckInterval - originalTimeout := heartbeatTimeout - healthCheckInterval = 1 * time.Second - heartbeatTimeout = healthCheckInterval + 500*time.Millisecond - testMutex.Unlock() + healthCheckInterval := 1 * time.Second - defer func() { - testMutex.Lock() - healthCheckInterval = originalInterval - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer 
os.Unsetenv(defaultAttemptThresholdEnv) + opts := ReceiverOptions{ + HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond, + AttemptThreshold: tc.threshold, + } - receiver := NewReceiver(log.WithField("test_name", tc.name)) + receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts) - testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval if tc.resetCounterOnce { receiver.Heartbeat() - t.Logf("reset counter once") } select { @@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { } t.Fatalf("should have timed out before %s", testTimeout) } - }) } } diff --git a/shared/relay/healthcheck/sender.go b/shared/relay/healthcheck/sender.go index 57b3015ec..771e94206 100644 --- a/shared/relay/healthcheck/sender.go +++ b/shared/relay/healthcheck/sender.go @@ -2,52 +2,76 @@ package healthcheck import ( "context" - "os" - "strconv" "time" log "github.com/sirupsen/logrus" ) const ( - defaultAttemptThreshold = 1 - defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" + defaultAttemptThreshold = 1 + + defaultHealthCheckInterval = 25 * time.Second + defaultHealthCheckTimeout = 20 * time.Second ) -var ( - healthCheckInterval = 25 * time.Second - healthCheckTimeout = 20 * time.Second -) +type SenderOptions struct { + HealthCheckInterval time.Duration + HealthCheckTimeout time.Duration + AttemptThreshold int +} // Sender is a healthcheck sender // It will send healthcheck signal to the receiver // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // It will also stop if the context is canceled type Sender struct { - log *log.Entry // HealthCheck is a channel to send health check signal to the peer HealthCheck chan struct{} // Timeout is a channel to the health check signal is not received in a certain time Timeout chan struct{} + log *log.Entry + 
healthCheckInterval time.Duration + timeout time.Duration + ack chan struct{} alive bool attemptThreshold int } -// NewSender creates a new healthcheck sender -func NewSender(log *log.Entry) *Sender { +func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender { + if opts.HealthCheckInterval <= 0 { + opts.HealthCheckInterval = defaultHealthCheckInterval + } + if opts.HealthCheckTimeout <= 0 { + opts.HealthCheckTimeout = defaultHealthCheckTimeout + } + if opts.AttemptThreshold <= 0 { + opts.AttemptThreshold = defaultAttemptThreshold + } hc := &Sender{ - log: log, - HealthCheck: make(chan struct{}, 1), - Timeout: make(chan struct{}, 1), - ack: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + log: log, + healthCheckInterval: opts.HealthCheckInterval, + timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout, + ack: make(chan struct{}, 1), + attemptThreshold: opts.AttemptThreshold, } return hc } +// NewSender creates a new healthcheck sender +func NewSender(log *log.Entry) *Sender { + opts := SenderOptions{ + HealthCheckInterval: defaultHealthCheckInterval, + HealthCheckTimeout: defaultHealthCheckTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewSenderWithOpts(log, opts) +} + // OnHCResponse sends an acknowledgment signal to the sender func (hc *Sender) OnHCResponse() { select { @@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() { } func (hc *Sender) StartHealthCheck(ctx context.Context) { - ticker := time.NewTicker(healthCheckInterval) + ticker := time.NewTicker(hc.healthCheckInterval) defer ticker.Stop() - timeoutTicker := time.NewTicker(hc.getTimeoutTime()) + timeoutTicker := time.NewTicker(hc.timeout) defer timeoutTicker.Stop() defer close(hc.HealthCheck) @@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) { } } } - -func (hc *Sender) getTimeoutTime() time.Duration { - return healthCheckInterval + 
healthCheckTimeout -} - -func getAttemptThresholdFromEnv() int { - if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { - threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) - if err != nil { - log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) - return defaultAttemptThreshold - } - return int(threshold) - } - return defaultAttemptThreshold -} diff --git a/shared/relay/healthcheck/sender_test.go b/shared/relay/healthcheck/sender_test.go index 23446366a..122fe0f16 100644 --- a/shared/relay/healthcheck/sender_test.go +++ b/shared/relay/healthcheck/sender_test.go @@ -2,26 +2,23 @@ package healthcheck import ( "context" - "fmt" - "os" "testing" "time" log "github.com/sirupsen/logrus" ) -func TestMain(m *testing.M) { - // override the health check interval to speed up the test - healthCheckInterval = 2 * time.Second - healthCheckTimeout = 100 * time.Millisecond - code := m.Run() - os.Exit(code) -} +var ( + testOpts = SenderOptions{ + HealthCheckInterval: 2 * time.Second, + HealthCheckTimeout: 100 * time.Millisecond, + } +) func TestNewHealthPeriod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) { func TestNewHealthFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), 
testOpts) go hc.StartHealthCheck(ctx) select { case <-hc.Timeout: - case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond): t.Fatalf("health check is not timed out") } } func TestNewHealthcheckStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) time.Sleep(100 * time.Millisecond) @@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) { func TestTimeoutReset(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - originalInterval := healthCheckInterval - originalTimeout := healthCheckTimeout - healthCheckInterval = 1 * time.Second - healthCheckTimeout = 500 * time.Millisecond - - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := SenderOptions{ + HealthCheckInterval: 1 * time.Second, + HealthCheckTimeout: 500 * time.Millisecond, + AttemptThreshold: tc.threshold, + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sender := NewSender(log.WithField("test_name", tc.name)) + sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts) 
senderExit := make(chan struct{}) go func() { sender.StartHealthCheck(ctx) @@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { } }() - testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval select { case <-sender.Timeout: @@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } - healthCheckInterval = originalInterval - healthCheckTimeout = originalTimeout }) } } - -//nolint:tenv -func TestGetAttemptThresholdFromEnv(t *testing.T) { - tests := []struct { - name string - envValue string - expected int - }{ - {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, - {"Custom attempt threshold when env is set to a valid integer", "3", 3}, - {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.envValue == "" { - os.Unsetenv(defaultAttemptThresholdEnv) - } else { - os.Setenv(defaultAttemptThresholdEnv, tt.envValue) - } - - result := getAttemptThresholdFromEnv() - if result != tt.expected { - t.Fatalf("Expected %d, got %d", tt.expected, result) - } - - os.Unsetenv(defaultAttemptThresholdEnv) - }) - } -} diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 82ab678f4..5ca0c0282 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -16,10 +16,10 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier 
is a wrapper interface of the status recorder @@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index db428515b..bc2d4d1be 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -22,7 +22,7 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -93,7 +93,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err) } var sockErr error @@ -102,7 +102,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error log.Errorf("Failed to create ipv6 raw socket: %v", err) } else { if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err) } } @@ -230,10 +230,8 @@ func (s *SharedSocket) Close() error { // read start a read loop for a specific receiver and sends the packet to the packetDemux channel func (s *SharedSocket) read(receiver receiver) { - // Buffer reuse is safe: packetDemux is unbuffered, so read() blocks until - // ReadFrom() synchronously processes the packet before next iteration - buf := make([]byte, s.mtu+maxIPUDPOverhead) for { + buf := make([]byte, s.mtu+maxIPUDPOverhead) n, addr, err := receiver(s.ctx, buf, 0) select { case <-s.ctx.Done(): 
diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 2e89b491a..1d76fa4e4 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/http" + // nolint:gosec _ "net/http/pprof" "strings" diff --git a/signal/peer/peer.go b/signal/peer/peer.go index f21c95a41..c9dd60fc0 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -5,10 +5,16 @@ import ( "sync" "time" + "errors" + log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/shared/signal/proto" + "github.com/netbirdio/netbird/signal/metrics" +) + +var ( + ErrPeerAlreadyRegistered = errors.New("peer already registered") ) // Peer representation of a connected Peer @@ -23,15 +29,18 @@ type Peer struct { // registration time RegisteredAt time.Time + + Cancel context.CancelFunc } // NewPeer creates a new instance of a connected Peer -func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer { +func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer { return &Peer{ Id: id, Stream: stream, StreamID: time.Now().UnixNano(), RegisteredAt: time.Now(), + Cancel: cancel, } } @@ -69,20 +78,24 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool { } // Register registers peer in the registry -func (registry *Registry) Register(peer *Peer) { +func (registry *Registry) Register(peer *Peer) error { start := time.Now() - registry.regMutex.Lock() - defer registry.regMutex.Unlock() - // can be that peer already exists, but it is fine (e.g. reconnect) p, loaded := registry.Peers.LoadOrStore(peer.Id, peer) if loaded { pp := p.(*Peer) - log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.", - peer.Id, peer.StreamID, pp.StreamID) - registry.Peers.Store(peer.Id, peer) - return + if peer.StreamID > pp.StreamID { + log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. 
Will override stream.", + peer.Id, peer.StreamID, pp.StreamID) + if swapped := registry.Peers.CompareAndSwap(peer.Id, pp, peer); !swapped { + return registry.Register(peer) + } + pp.Cancel() + log.Debugf("peer re-registered [%s]", peer.Id) + return nil + } + return ErrPeerAlreadyRegistered } log.Debugf("peer registered [%s]", peer.Id) @@ -92,22 +105,13 @@ func (registry *Registry) Register(peer *Peer) { registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) registry.metrics.Registrations.Add(context.Background(), 1) + + return nil } // Deregister Peer from the Registry (usually once it disconnects) func (registry *Registry) Deregister(peer *Peer) { - registry.regMutex.Lock() - defer registry.regMutex.Unlock() - - p, loaded := registry.Peers.LoadAndDelete(peer.Id) - if loaded { - pp := p.(*Peer) - if peer.StreamID < pp.StreamID { - registry.Peers.Store(peer.Id, p) - log.Debugf("attempted to remove newer registered stream of a peer [%s] [newer streamID %d, previous StreamID %d]. 
Ignoring.", - peer.Id, pp.StreamID, peer.StreamID) - return - } + if deleted := registry.Peers.CompareAndDelete(peer.Id, peer); deleted { registry.metrics.ActivePeers.Add(context.Background(), -1) log.Debugf("peer deregistered [%s]", peer.Id) registry.metrics.Deregistrations.Add(context.Background(), 1) diff --git a/signal/peer/peer_test.go b/signal/peer/peer_test.go index fb85fedda..6b7976eb4 100644 --- a/signal/peer/peer_test.go +++ b/signal/peer/peer_test.go @@ -1,13 +1,18 @@ package peer import ( + "context" + "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" ) @@ -19,12 +24,16 @@ func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T) peerID := "peer" - olderPeer := NewPeer(peerID, nil) - r.Register(olderPeer) + _, cancel1 := context.WithCancel(context.Background()) + olderPeer := NewPeer(peerID, nil, cancel1) + err = r.Register(olderPeer) + require.NoError(t, err) time.Sleep(time.Nanosecond) - newerPeer := NewPeer(peerID, nil) - r.Register(newerPeer) + _, cancel2 := context.WithCancel(context.Background()) + newerPeer := NewPeer(peerID, nil, cancel2) + err = r.Register(newerPeer) + require.NoError(t, err) registered, _ := r.Get(olderPeer.Id) assert.NotNil(t, registered, "peer can't be nil") @@ -59,10 +68,14 @@ func TestRegistry_Register(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - peer1 := NewPeer("test_peer_1", nil) - peer2 := NewPeer("test_peer_2", nil) - r.Register(peer1) - r.Register(peer2) + _, cancel1 := context.WithCancel(context.Background()) + peer1 := NewPeer("test_peer_1", nil, cancel1) + _, cancel2 := context.WithCancel(context.Background()) + peer2 := NewPeer("test_peer_2", nil, cancel2) + err = r.Register(peer1) + require.NoError(t, err) + err = r.Register(peer2) + 
require.NoError(t, err) if _, ok := r.Get("test_peer_1"); !ok { t.Errorf("expected test_peer_1 not found in the registry") @@ -78,10 +91,14 @@ func TestRegistry_Deregister(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - peer1 := NewPeer("test_peer_1", nil) - peer2 := NewPeer("test_peer_2", nil) - r.Register(peer1) - r.Register(peer2) + _, cancel1 := context.WithCancel(context.Background()) + peer1 := NewPeer("test_peer_1", nil, cancel1) + _, cancel2 := context.WithCancel(context.Background()) + peer2 := NewPeer("test_peer_2", nil, cancel2) + err = r.Register(peer1) + require.NoError(t, err) + err = r.Register(peer2) + require.NoError(t, err) r.Deregister(peer1) @@ -94,3 +111,213 @@ func TestRegistry_Deregister(t *testing.T) { } } + +func TestRegistry_MultipleRegister_Concurrency(t *testing.T) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + + numGoroutines := 1000 + + ids := make(chan int64, numGoroutines) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + peerID := "peer-concurrent" + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + _, cancel := context.WithCancel(context.Background()) + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + ids <- peer.StreamID + }(i) + } + + wg.Wait() + close(ids) + maxId := int64(0) + for id := range ids { + maxId = max(maxId, id) + } + + peer, ok := registry.Get(peerID) + require.True(t, ok, "expected peer to be registered") + require.Equal(t, maxId, peer.StreamID, "expected the highest StreamID to be registered") +} + +func Benchmark_MultipleRegister_Concurrency(b *testing.B) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(b, err) + + numGoroutines := 1000 + + var wg sync.WaitGroup + peerID := "peer-concurrent" + _, cancel := context.WithCancel(context.Background()) + b.Run("multiple-register", func(b *testing.B) { + registry := NewRegistry(metrics) + 
b.ResetTimer() + for j := 0; j < b.N; j++ { + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + }(i) + } + wg.Wait() + } + }) +} + +func TestRegistry_MultipleDeregister_Concurrency(t *testing.T) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + + numGoroutines := 1000 + + ids := make(chan int64, numGoroutines) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + peerID := "peer-concurrent" + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + _, cancel := context.WithCancel(context.Background()) + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + ids <- peer.StreamID + registry.Deregister(peer) + }(i) + } + + wg.Wait() + close(ids) + maxId := int64(0) + for id := range ids { + maxId = max(maxId, id) + } + + _, ok := registry.Get(peerID) + require.False(t, ok, "expected peer to be deregistered") +} + +func Benchmark_MultipleDeregister_Concurrency(b *testing.B) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(b, err) + + numGoroutines := 1000 + + var wg sync.WaitGroup + peerID := "peer-concurrent" + _, cancel := context.WithCancel(context.Background()) + b.Run("register-deregister", func(b *testing.B) { + registry := NewRegistry(metrics) + b.ResetTimer() + for j := 0; j < b.N; j++ { + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + time.Sleep(time.Nanosecond) + registry.Deregister(peer) + }(i) + } + wg.Wait() + } + }) +} + +type mockConnectStreamServer struct { + grpc.ServerStream + ctx context.Context +} + +func (m *mockConnectStreamServer) Context() context.Context { + return m.ctx +} + +func (m *mockConnectStreamServer) SendHeader(md metadata.MD) error { + return nil 
+} + +func (m *mockConnectStreamServer) Send(msg *proto.EncryptedMessage) error { + return nil +} + +func (m *mockConnectStreamServer) Recv() (*proto.EncryptedMessage, error) { + <-m.ctx.Done() + return nil, m.ctx.Err() +} + +func TestReconnectHandling(t *testing.T) { + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + peerID := "test-peer-reconnect" + + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + stream1 := &mockConnectStreamServer{ctx: ctx1} + peer1 := NewPeer(peerID, stream1, cancel1) + + err = registry.Register(peer1) + require.NoError(t, err, "first registration should succeed") + + p, found := registry.Get(peerID) + require.True(t, found, "peer should be found in the registry") + require.Equal(t, peer1.StreamID, p.StreamID, "StreamID of registered peer should match") + + time.Sleep(time.Nanosecond) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + stream2 := &mockConnectStreamServer{ctx: ctx2} + peer2 := NewPeer(peerID, stream2, cancel2) + + err = registry.Register(peer2) + require.NoError(t, err, "reconnect registration should succeed") + + select { + case <-ctx1.Done(): + require.ErrorIs(t, ctx1.Err(), context.Canceled, "context of old stream should be canceled after successful reconnection") + case <-time.After(100 * time.Millisecond): + t.Fatal("context of old stream was not canceled after reconnection") + } + + p, found = registry.Get(peerID) + require.True(t, found) + require.Equal(t, peer2.StreamID, p.StreamID, "registered peer should have the new StreamID after reconnection") + + ctx3, cancel3 := context.WithCancel(context.Background()) + defer cancel3() + stream3 := &mockConnectStreamServer{ctx: ctx3} + stalePeer := NewPeer(peerID, stream3, cancel3) + stalePeer.StreamID = peer1.StreamID + + err = registry.Register(stalePeer) + require.ErrorIs(t, err, ErrPeerAlreadyRegistered, "reconnecting with an old StreamID should 
return an error") + + p, found = registry.Get(peerID) + require.True(t, found) + require.Equal(t, peer2.StreamID, p.StreamID, "active peer should still be the one with the latest StreamID") + + select { + case <-ctx2.Done(): + t.Fatal("context of the new stream should not be canceled after trying to register with an old StreamID") + default: + } +} diff --git a/signal/server/signal.go b/signal/server/signal.go index 8ae14822b..47f01edae 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -2,7 +2,9 @@ package server import ( "context" + "errors" "fmt" + "os" "time" log "github.com/sirupsen/logrus" @@ -15,9 +17,9 @@ import ( "github.com/netbirdio/signal-dispatcher/dispatcher" + "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" - "github.com/netbirdio/netbird/shared/signal/proto" ) const ( @@ -27,6 +29,8 @@ const ( labelTypeNotRegistered = "not_registered" labelTypeStream = "stream" labelTypeMessage = "message" + labelTypeTimeout = "timeout" + labelTypeDisconnected = "disconnected" labelError = "error" labelErrorMissingId = "missing_id" @@ -37,6 +41,12 @@ const ( labelRegistrationStatus = "status" labelRegistrationFound = "found" labelRegistrationNotFound = "not_found" + + sendTimeout = 10 * time.Second +) + +var ( + ErrPeerRegisteredAgain = errors.New("peer registered again") ) // Server an instance of a Signal server @@ -45,6 +55,10 @@ type Server struct { proto.UnimplementedSignalExchangeServer dispatcher *dispatcher.Dispatcher metrics *metrics.AppMetrics + + successHeader metadata.MD + + sendTimeout time.Duration } // NewServer creates a new Signal server @@ -59,10 +73,19 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating dispatcher: %v", err) } + sTimeout := sendTimeout + to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT") + if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 { + 
log.Trace("using custom send timeout ", parsed) + sTimeout = parsed + } + s := &Server{ - dispatcher: d, - registry: peer.NewRegistry(appMetrics), - metrics: appMetrics, + dispatcher: d, + registry: peer.NewRegistry(appMetrics), + metrics: appMetrics, + successHeader: metadata.Pairs(proto.HeaderRegistered, "1"), + sendTimeout: sTimeout, } return s, nil @@ -82,7 +105,8 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. // ConnectStream connects to the exchange stream func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error { - p, err := s.RegisterPeer(stream) + ctx, cancel := context.WithCancel(context.Background()) + p, err := s.RegisterPeer(stream, cancel) if err != nil { return err } @@ -90,8 +114,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) defer s.DeregisterPeer(p) // needed to confirm that the peer has been registered so that the client can proceed - header := metadata.Pairs(proto.HeaderRegistered, "1") - err = stream.SendHeader(header) + err = stream.SendHeader(s.successHeader) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader))) return err @@ -99,27 +122,27 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) - <-stream.Context().Done() - log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) - return nil + select { + case <-stream.Context().Done(): + log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) + return nil + case <-ctx.Done(): + return ErrPeerRegisteredAgain + } } -func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { +func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) { log.Debugf("registering new 
peer") - meta, hasMeta := metadata.FromIncomingContext(stream.Context()) - if !hasMeta { - s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta))) - return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") - } - - id, found := meta[proto.HeaderId] - if !found { + id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId) + if id == nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId) } - p := peer.NewPeer(id[0], stream) - s.registry.Register(p) + p := peer.NewPeer(id[0], stream, cancel) + if err := s.registry.Register(p); err != nil { + return nil, err + } err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration))) @@ -131,8 +154,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) ( func (s *Server) DeregisterPeer(p *peer.Peer) { log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) - s.registry.Deregister(p) s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) + s.registry.Deregister(p) } func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { @@ -145,7 +168,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM if !found { s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) s.metrics.MessageForwardFailures.Add(ctx, 1, 
metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) - log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) + log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) // todo respond to the sender? return } @@ -153,16 +176,34 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) start := time.Now() - // forward the message to the target peer - if err := dstPeer.Stream.Send(msg); err != nil { - log.Tracef("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) - // todo respond to the sender? - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - return - } + sendResultChan := make(chan error, 1) + go func() { + select { + case sendResultChan <- dstPeer.Stream.Send(msg): + return + case <-dstPeer.Stream.Context().Done(): + return + } + }() - // in milliseconds - s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) - s.metrics.MessagesForwarded.Add(ctx, 1) - s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) + select { + case err := <-sendResultChan: + if err != nil { + log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + return + } + s.metrics.MessageForwardLatency.Record(ctx, 
float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(ctx, 1) + s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) + + case <-dstPeer.Stream.Context().Done(): + log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected))) + + case <-time.After(s.sendTimeout): + dstPeer.Cancel() // cancel the peer context to trigger deregistration + log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout))) + } } diff --git a/util/net/conn.go b/util/net/conn.go deleted file mode 100644 index 26693f841..000000000 --- a/util/net/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !ios - -package net - -import ( - "net" - - log "github.com/sirupsen/logrus" -) - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} diff --git a/util/net/dial.go b/util/net/dial.go deleted file mode 100644 index 595311492..000000000 --- a/util/net/dial.go +++ /dev/null @@ -1,58 +0,0 @@ -//go:build !ios - -package net - -import ( - "fmt" - "net" - - log "github.com/sirupsen/logrus" -) - -func DialUDP(network string, laddr, raddr 
*net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_dial.go b/util/net/dialer_dial.go deleted file mode 100644 index 1659b6220..000000000 --- a/util/net/dialer_dial.go +++ /dev/null @@ -1,107 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. 
-func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHooks removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - log.Debugf("Dialing %s %s", network, address) - - if CustomRoutingDisabled() { - return d.Dialer.DialContext(ctx, network, address) - } - - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := callDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - 
ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} diff --git a/util/net/dialer_init_nonlinux.go b/util/net/dialer_init_nonlinux.go deleted file mode 100644 index 8c57ebbaa..000000000 --- a/util/net/dialer_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (d *Dialer) init() { - // implemented on Linux and Android only -} diff --git a/util/net/env_generic.go b/util/net/env_generic.go deleted file mode 100644 index 6d142a838..000000000 --- a/util/net/env_generic.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !linux || android - -package net - -func Init() { - // nothing to do on non-linux -} - -func AdvancedRouting() bool { - // non-linux currently doesn't support advanced routing - return false -} diff --git a/util/net/listen.go b/util/net/listen.go deleted file mode 100644 index 3ae8a9435..000000000 --- a/util/net/listen.go +++ /dev/null @@ -1,37 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" -) - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. 
-func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_init_nonlinux.go b/util/net/listener_init_nonlinux.go deleted file mode 100644 index 80f6f7f1a..000000000 --- a/util/net/listener_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (l *ListenerConfig) init() { - // implemented on Linux and Android only -} diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go deleted file mode 100644 index 4060ab49a..000000000 --- a/util/net/listener_listen.go +++ /dev/null @@ -1,205 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. 
-type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc - listenerAddressRemoveHooksMutex sync.RWMutex - listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. -func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) -} - -// RemoveListenerHooks removes all listener hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil - - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. 
-func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - if CustomRoutingDisabled() { - return l.ListenConfig.ListenPacket(ctx, network, address) - } - - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -// RemoveAddress removes an address from the seen cache and triggers removal hooks. 
-func (c *PacketConn) RemoveAddress(addr string) { - if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { - return - } - - ipStr, _, err := net.SplitHostPort(addr) - if err != nil { - log.Errorf("Error splitting IP address and port: %v", err) - return - } - - ipAddr, err := netip.ParseAddr(ipStr) - if err != nil { - log.Errorf("Error parsing IP address %s: %v", ipStr, err) - return - } - - prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) - - listenerAddressRemoveHooksMutex.RLock() - defer listenerAddressRemoveHooksMutex.RUnlock() - - for _, hook := range listenerAddressRemoveHooks { - if err := hook(c.ID, prefix); err != nil { - log.Errorf("Error executing listener address remove hook: %v", err) - } - } -} - - -// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality -func WrapPacketConn(conn net.PacketConn) *PacketConn { - return &PacketConn{ - PacketConn: conn, - ID: GenerateConnID(), - seenAddrs: &sync.Map{}, - } -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := 
range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -}